Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
31c9982
Feat (equalize): adding support for permutations
i-colbert Dec 27, 2025
eddc639
Fix (equalize): minor fixes from rebase
i-colbert Jan 31, 2026
16e309a
Docs (papers): adding example config for mixquant
i-colbert Feb 2, 2026
e0507cb
Feat (test): adding tests for rotate_permute_mode
i-colbert Feb 3, 2026
fdf3bc8
Fix (tests): fixing rotate_permute_mode test
i-colbert Feb 3, 2026
ceac637
Fix (equalize): decouple GraphPermutationEqualization from GraphRotat…
i-colbert Feb 12, 2026
612694b
Fix (equalize): registering permutation algos
i-colbert Feb 12, 2026
3debec2
Fix (permute): refactor permutations into new permute.py
i-colbert Feb 12, 2026
acc68ef
Fix (permute): passing permute_fn to GraphPermutationEqualization
i-colbert Feb 12, 2026
bbe6651
Fix (utils): typing hints
i-colbert Feb 12, 2026
73a1ef4
Fix (utils): import error from refactor
i-colbert Feb 12, 2026
7e0e86e
Fix (permute): cleaning up __all__ in permute.py
i-colbert Feb 12, 2026
3a61e03
Fix (permute): adding get_regions() call to GraphRotationEqualization
i-colbert Feb 13, 2026
e04bac9
Fix (permute): adding filter_permutations()
i-colbert Feb 13, 2026
02b9ba1
Fix (args): combine apply_permute with permute_fn
i-colbert Feb 13, 2026
e6c2692
Fix (permute): inherit from RegionalWalkMixin
i-colbert Feb 13, 2026
3ec8358
Fix (permute): inherit from GraphTransform
i-colbert Feb 13, 2026
40e03de
Fix (permute): updating rotate_permute_mode
i-colbert Feb 13, 2026
e33b778
Fix (tests): updating the tests
i-colbert Feb 13, 2026
558d880
Fix (permute): another step towards decoupling
i-colbert Feb 14, 2026
ae7b21c
Fix (equalize): removing code comment
i-colbert Feb 16, 2026
0e76a03
Fix (permute): adding extra_state_kwargs to LLM entrypoint
i-colbert Feb 16, 2026
66335da
Fix (main): initializing extra_state_kwargs once
i-colbert Feb 16, 2026
1e49d78
Fix (permute): move region check
i-colbert Feb 16, 2026
8d9cf0c
Fix (papers): remove apply_permute from YAML file
i-colbert Feb 16, 2026
411d867
Fix (graph): sharing _process_input across transforms
i-colbert Feb 16, 2026
5174989
Fix (tests): fixing tests after refactor
i-colbert Feb 17, 2026
e071905
fix tests
Giuseppe5 Feb 17, 2026
21143af
fix tests
Giuseppe5 Feb 17, 2026
7e90c34
Fix (tests): test_permute fix
i-colbert Feb 18, 2026
93aa8b0
Fix (tests): fixing errors with test_permute
i-colbert Feb 18, 2026
3a1703a
Fix (equalize): fixing code comment
i-colbert Feb 18, 2026
12c0b16
Fix (tests): remove code comment
i-colbert Feb 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from brevitas.graph.base import Transform
from brevitas.graph.hadamard import find_closest_hadamard_number
from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import is_pow2
from brevitas.graph.hadamard import matmul_hadU
from brevitas.graph.hadamard import matmul_hadU_cuda
from brevitas.graph.hadamard import random_hadamard_matrix
Expand Down Expand Up @@ -720,6 +719,18 @@ def from_module_indexes(

return cls(module, weight_axis, act_axis, indexes)

def permute(self, permute_index):
    """Reorder the wrapped module's weight, and its bias when present, in place.

    Entries are re-indexed along ``self.weight_axis`` with
    ``torch.index_select``; the index tensor is moved to the device of each
    tensor before it is used.

    Args:
        permute_index: index tensor handed to ``torch.index_select``
            (presumably a 1-D integer permutation of the channels along
            ``self.weight_axis`` -- confirm against the caller).
    """
    weight = self.module.weight
    weight.data = torch.index_select(
        weight.data, self.weight_axis, permute_index.to(weight.device))

    # A None default covers both a module without the attribute and a module
    # built with bias=None (where the attribute exists but holds None).
    bias = getattr(self.module, self._bias_tensor_name, None)
    if bias is not None:
        # Operate on the tensor looked up through _bias_tensor_name rather than
        # the hard-coded `module.bias`, so a non-default bias name is honoured.
        # NOTE(review): the bias is indexed along the same axis as the weight;
        # verify this for wrappers where weight_axis != 0.
        bias.data = torch.index_select(
            bias.data, self.weight_axis, permute_index.to(bias.device))


class EqualizationSinkWrapper(EqualizationModuleWrapper):

Expand Down Expand Up @@ -760,6 +771,10 @@ def from_module_indexes(
weight_tensor_name = "weight"
return cls(module, weight_axis, act_axis, indexes, weight_tensor_name)

def permute(self, permute_index):
    """Reorder the wrapped module's weight in place along ``self.weight_axis``.

    Args:
        permute_index: index tensor forwarded to ``torch.index_select``; it is
            moved to the weight's device before use.
    """
    device_index = permute_index.to(self.module.weight.device)
    reordered = torch.index_select(
        self.module.weight.data, dim=self.weight_axis, index=device_index)
    self.module.weight.data = reordered


# When fuse_scaling = False, the scaling parameters are instances of nn.Parameter,
# which are registered to the scaling modules (used in the parametrization of the
Expand Down Expand Up @@ -2060,6 +2075,8 @@ def find_sink(node):
end_index = head_dim if head_dim != -1 else output_weight.shape[0]
output_index = EqualizationIndexes(0, end_index, 0)

# NOTE: GraphPermutationEqualization.extract_permute_regions looks for these src and
# sink names to delineate SDPA regions
region = Region.from_dicts(
srcs={'value_sdpa': value_index},
sinks={'output_sdpa': output_index},
Expand Down Expand Up @@ -2128,6 +2145,9 @@ def apply(self,
added_regions += 1
logging.debug(f"Adding {added_regions} sink-only regions")

# Store regions for potential use by GraphPermutationEqualization
self.regions = regions

if overlap:
assert not self.use_parametrized_rotations, "Overlap between expanded and optimized region not supported"
first_set, second_set = regions, expanded_regions
Expand Down
Loading
Loading