diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index 2e5bdc9b5..ecd291472 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -7,6 +7,8 @@ from inspect import getcallargs from typing import Any from typing import Dict +from typing import Optional +from typing import Tuple from typing import Type import torch @@ -21,6 +23,8 @@ from brevitas.graph.utils import * from brevitas.utils.python_utils import islambda +INPUT_NAMES = ('input', 'inp', 'query', 'hidden_states', 'x') + __all__ = [ 'Transform', 'PerInputTransform', @@ -59,6 +63,44 @@ class GraphTransform(Transform): def apply(self, graph_model: GraphModule) -> GraphModule: pass + def _process_input( + self, + module: Module, + args: tuple, + kwargs: dict, + batch_dim: int = 0, + use_inp: bool = True) -> Tuple[Optional[torch.Tensor], Optional[int]]: + """ + Process input from forward hook, handling MHA cross-attention + """ + # Check for MHA Cross attention, and if found, skip it + # When using hf/accelerate, we need to check the signature of the original forward + forward_to_check = module._old_forward if hasattr( + module, '_old_forward') else module.forward + kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], args[:-1])) + + # Check for cross-attention in MHA (skip if found) + if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs: + if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr(): + return None, None + + inp_kwarg = [x for x in kwargs.keys() if x in INPUT_NAMES][0] + if use_inp: + inp = kwargs[inp_kwarg] + else: + inp = args[-1] + + # Handle case where inp is a tuple (common in forward hooks) + if isinstance(inp, tuple): + assert len(inp) == 1, "Expected single element tuple" + inp = inp[0] + + # Extra check for batch_dim using named tensors + if hasattr(inp, 'names') and 'N' in inp.names: + batch_dim = inp.names.index('N') + + return inp, batch_dim + class UntilFixedPointGraphTransform(Transform): diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index b987ad27a..0a7fb39d5 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -39,7 +39,6 @@ from brevitas.graph.base import Transform from brevitas.graph.hadamard import find_closest_hadamard_number from brevitas.graph.hadamard import get_hadK -from brevitas.graph.hadamard import is_pow2 from brevitas.graph.hadamard import matmul_hadU from brevitas.graph.hadamard import matmul_hadU_cuda from brevitas.graph.hadamard import random_hadamard_matrix @@ -48,7 +47,6 @@ from brevitas.nn import ScaledDotProductAttention from brevitas.nn.equalized_layer import EqualizedModule from brevitas.nn.equalized_layer import functional_rotate_input -from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.equalized_layer import RotatedModule from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.proxy import BiasQuantProxyFromInjectorBase @@ -720,6 +718,18 @@ def from_module_indexes( return cls(module, weight_axis, act_axis, indexes) + def permute(self, permute_index): + self.module.weight.data = torch.index_select( + self.module.weight.data, self.weight_axis, permute_index.to(self.module.weight.device)) + if hasattr(self.module, self._bias_tensor_name): + bias = getattr(self.module, self._bias_tensor_name) + # hasattr returns true if bias=None + if bias is not None: + bias.data = torch.index_select( + self.module.bias.data, + self.weight_axis, + permute_index.to(self.module.bias.device)) + class 
EqualizationSinkWrapper(EqualizationModuleWrapper): @@ -760,6 +770,10 @@ def from_module_indexes( weight_tensor_name = "weight" return cls(module, weight_axis, act_axis, indexes, weight_tensor_name) + def permute(self, permute_index): + self.module.weight.data = torch.index_select( + self.module.weight.data, self.weight_axis, permute_index.to(self.module.weight.device)) + # When fuse_scaling = False, the scaling parameters are instances of nn.Parameter, # which are registered to the scaling modules (used in the parametrization of the @@ -1286,25 +1300,11 @@ def create_mul_node(self, scale, shape, axis, batch_dim=0): return mul_factor def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs): - # Check for MHA Cross attention, and if found, skip it - # When using hf/accelerate, we need to check the signature of the original forward - forward_to_check = module._old_forward if hasattr( - module, '_old_forward') else module.forward - kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], args[:-1])) - if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs: - if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr(): - self.float_act_map[name] = None - return - - input_kwarg = [x for x in kwargs.keys() if x in INPUT_NAMES][0] - if use_inp: - x = kwargs[input_kwarg] - elif not use_inp: - x = args[-1] + x, batch_dim = self._process_input(module, args, kwargs, batch_dim, use_inp) - # Extra check for batch_dim - if hasattr(x, 'names') and 'N' in x.names: - batch_dim = x.names.index('N') + if x is None: + self.float_act_map[name] = None + return self.batch_dim_act_map[name] = batch_dim @@ -1969,11 +1969,12 @@ def __init__( self.delay_rewriters = delay_rewriters self.block_rotation_dim = block_rotation_dim self.disable_block_rotation_for_fused = disable_block_rotation_for_fused + self.regions = [] if self.delay_rewriters: assert return_rewriters, "If `delay_rewriters=True`, rewriters are not applied immediately. Therefore, these must be returned, by setting `return_rewriters=True`, to be applied at a later stage." if use_parametrized_rotations: - # NOTE: When use_parametrized_rotations=False, parametrized rotations are applied. This changes the attribute __class__ + # NOTE: When use_parametrized_rotations=True, parametrized rotations are applied. This changes the attribute __class__ # of the parametrized module, e.g. to"". # Therefore, algorithms that do type checking might need to use type_before_parametrizations(module), # instead of only type(module) (see layerwise_layer_handler). 
Algorithms that rely on in-place modifications @@ -1984,6 +1985,10 @@ def __init__( ) self.use_parametrized_rotations = use_parametrized_rotations + def get_regions(self) -> List[Region]: + """Return the list of regions identified during graph rotation equalization.""" + return self.regions + def rotate_matmuls(self, graph_module): matmul_nodes = list(graph_module.graph.nodes) matmul_nodes = [c for c in matmul_nodes if c.name == 'matmul'] @@ -2077,7 +2082,7 @@ def find_sink(node): def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] - regions = _extract_regions(graph_model, state_impl_kwargs=self.full_state_kwargs) + self.regions = _extract_regions(graph_model, state_impl_kwargs=self.full_state_kwargs) expanded_regions = [] self.find_module_by_name(graph_model, expanded_regions) @@ -2100,11 +2105,11 @@ def apply(self, if self.sdpa_regions: sdpa_regions = self.rotate_sdpa(graph_model) - regions.extend(sdpa_regions) + self.regions.extend(sdpa_regions) - logging.debug(f"Applying GraphRotationEqualization on {len(regions)} regions") + logging.debug(f"Applying GraphRotationEqualization on {len(self.regions)} regions") - for r in regions: + for r in self.regions: id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names] eq_layers.update(id_list) @@ -2124,21 +2129,21 @@ def apply(self, # Layerwise have only a single sink named 'sinks0' id_sink = id(o_r.get_module_from_name('sinks0')) if id_sink not in eq_layers: - regions.append(o_r) + self.regions.append(o_r) added_regions += 1 logging.debug(f"Adding {added_regions} sink-only regions") if overlap: assert not self.use_parametrized_rotations, "Overlap between expanded and optimized region not supported" - first_set, second_set = regions, expanded_regions + first_set, second_set = self.regions, expanded_regions first_exp_step, second_exp_step = 1, self.expansion_step else: - first_set, second_set = expanded_regions, regions + first_set, second_set = expanded_regions, self.regions first_exp_step, second_exp_step = self.expansion_step, 1 if self.rotate_matmul: self.rotate_matmuls(graph_model) - if len(regions) > 0: + if len(self.regions) > 0: rewriters.extend( _compute_rotations( graph_model, diff --git a/src/brevitas/graph/permute.py b/src/brevitas/graph/permute.py new file mode 100644 index 000000000..3833b78f0 --- /dev/null +++ b/src/brevitas/graph/permute.py @@ -0,0 +1,384 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +from functools import partial +import operator +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type + +import torch +from torch.fx import GraphModule +import torch.nn as nn +from tqdm import tqdm + +from brevitas.graph.base import GraphTransform +from brevitas.graph.equalize import _channel_maxabs +from brevitas.graph.equalize import _scale_invariant_layers +from brevitas.graph.equalize import _UNSUPPORTED_OP +from brevitas.graph.equalize import find_srcs +from brevitas.graph.equalize import GraphRotationEqualization +from brevitas.graph.equalize import Region +from brevitas.graph.equalize import RegionWalkMixin +from brevitas.graph.equalize import WalkRegionState +from brevitas.graph.utils import find_node_for_module +from brevitas.nn.equalized_layer import RotatedModule +from brevitas.utils.logging import setup_logger + +logging = setup_logger(__name__) + +__all__ = ['GraphPermutationEqualization', 'rotate_permute_mode'] + +# Initialize permutation-invariant layers from scale-invariant layers +_permute_invariant_layers = list(_scale_invariant_layers) +_permute_invariant_layers.extend([torch.nn.GELU, torch.nn.SELU, torch.nn.SiLU]) + +# Try to add RMSNorm +try: + from torch.nn import RMSNorm + _permute_invariant_layers.append(RMSNorm) +except: + pass + +_permute_invariant_layers = tuple(_permute_invariant_layers) +_permute_invariant_functions = (torch.nn.functional.silu,) + +# Dictionary to store registered permutation methods +_PERMUTATION_METHODS = {} + + +def register_permutation_method(name: str): + """ + Register a permutation method. + + Args: + name: The name of the permutation method (e.g., "zigzag", "massdiff") + + Examples: + >>> @register_permutation_method("my_permute") + ... def my_permute_method(x, block_rotation_dim): + ... return torch.arange(x.shape[-1]) + """ + + def _wrapper(permute_fn): + if name in _PERMUTATION_METHODS: + logging.warning( + "The permutation method '%s' already exists and will be " + "overwritten by %s.", + name, + permute_fn.__name__, + ) + _PERMUTATION_METHODS[name] = permute_fn + return permute_fn + + return _wrapper + + +def get_permutation_method(name: str): + """Get a registered permutation method by name.""" + if name not in _PERMUTATION_METHODS: + available = list(_PERMUTATION_METHODS.keys()) + raise ValueError( + f"Permutation method '{name}' not found. 
" + f"Available methods: {available}") + return _PERMUTATION_METHODS[name] + + +@register_permutation_method("zigzag") +def zigzag_permute(x, block_size): + if x.shape[-1] == block_size: + return torch.arange(block_size).to(x.device) + scores = _channel_maxabs(x, dim=0) + _, indexes = torch.sort(scores, descending=True) + # Inline zigzag sort logic + indexes = indexes.view(block_size, indexes.shape[-1] // block_size) + indexes[1::2] = torch.flip(indexes[1::2], dims=[1]) + indexes = indexes.t() + indexes = indexes.flatten() + return indexes + + +@register_permutation_method("random") +def random_permute(x, block_size): + if x.shape[-1] == block_size: + return torch.arange(block_size).to(x.device) + indexes = torch.randperm(x.shape[-1]).to(x.device) + return indexes + + +@register_permutation_method("absmax") +def absmax_permute(x, block_size): + if x.shape[-1] == block_size: + return torch.arange(block_size).to(x.device) + scores = _channel_maxabs(x, dim=0) + _, indexes = torch.sort(scores, descending=True) + return indexes + + +@register_permutation_method("massdiff") +def massdiff_permute(x, block_size): + if x.shape[-1] == block_size: + return torch.arange(block_size).to(x.device) + # initialize the blocks based on absmax scores + scores = torch.abs(x).mean(dim=0) + _, indexes = torch.sort(scores, descending=True) + num_blocks = x.shape[-1] // block_size + # initialize the block norms and indexes + block_norm = torch.stack([torch.abs(x[:, i]) for i in indexes[:num_blocks]], dim=1) + block_idxs = [[i] for i in indexes[:num_blocks]] + for i in indexes[num_blocks:]: + # find the block that will have the minimum l1-norm after adding the new index + norms_after_adding = block_norm + torch.abs(x[:, i]).unsqueeze(1) + norms_after_adding = torch.mean(norms_after_adding, dim=0) + min_block = torch.argmin(norms_after_adding) + # update the block norm and indexes + block_norm[:, min_block] += torch.abs(x[:, i]) + block_idxs[min_block].append(i) + # mark block as full + if (len(block_idxs[min_block]) == block_size): + block_norm[:, min_block] = float('inf') + indexes = torch.tensor(block_idxs).flatten() + return indexes + + +class GraphPermutationEqualization(GraphTransform, RegionWalkMixin): + """ + A class for managing and applying permutations to a computational graph + """ + + def __init__( + self, + block_size: int, + permute_fn: str = 'massdiff', + extra_state_kwargs: Optional[Dict[str, Tuple[Type[nn.Module]]]] = None): + assert isinstance(block_size, int) and block_size > 1, "Error: expected an integer > 1." + assert permute_fn in _PERMUTATION_METHODS, f"Error: {permute_fn} is not registered." 
+ + # Initialize RegionWalkMixin + mul_ops = [torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__] + residual_fns = [torch.add, operator.add, operator.iadd, operator.__add__, operator.__iadd__] + residual_fns.extend(mul_ops) + + base_state_kwargs = { + 'supported_srcs': (nn.Embedding, RotatedModule, nn.Linear), + 'supported_sinks': (nn.Linear, RotatedModule), + 'scale_invariant_layers': _permute_invariant_layers, + 'scale_invariant_functions': _permute_invariant_functions, + 'residual_fns': tuple(residual_fns),} + RegionWalkMixin.__init__(self, **base_state_kwargs, extra_state_kwargs=extra_state_kwargs) + + # Initialize other attributes + self.hooks = [] + self.hooked_modules = set() + self.regions = list() + self.float_act_map = dict() + self.float_act_dev = dict() + self.block_size = block_size + self.permute_fn = get_permutation_method(permute_fn) + + def setup(self, graph_model: GraphModule, regions: List[Region]) -> GraphModule: + """Extract regions and setup hooks""" + self._extract_regions(graph_model, regions) + self._setup_hooks() + return graph_model + + def forward_stats_hook(self, module, *args, name, batch_dim=0, **kwargs): + inp, batch_dim = self._process_input(module, args, kwargs, batch_dim, use_inp=True) + + if inp is None: + return + + if hasattr(inp, 'names') and 'N' in inp.names: + inp.rename_(None) + inp = inp.transpose(0, batch_dim) + + inp = inp.reshape(-1, inp.shape[-1]) # [batch_size * seq_len, dim] + if name not in self.float_act_map: + self.float_act_map[name] = [] + self.float_act_dev[name] = inp.device + self.float_act_map[name].append(inp.detach().cpu()) + + def _setup_hooks(self): + for region in self.regions: + # We assume that the entire region has a unique batch_dim + batch_dim = 0 + for name in region.srcs: + module = region.get_module_from_name(name) + if hasattr(module, 'batch_first') and not module.batch_first: + batch_dim = 1 + for name in region.sinks: + module = region.get_module_from_name(name) + if hasattr(module, 'batch_first') and not module.batch_first: + batch_dim = 1 + + for name in region.sinks_names: + module = region.get_module_from_name(name) + if module not in self.hooked_modules: + self.hooked_modules.add(module) + hook_fn = partial(self.forward_stats_hook, name=name, batch_dim=batch_dim) + h = module.register_forward_hook(hook_fn) + self.hooks.append(h) + + def _is_compatible_region(self, region: Region) -> bool: + if (region.max_shape_sinks // self.block_size > 1) and \ + (region.max_shape_sinks % self.block_size == 0): + return True + return False + + def _extract_regions(self, graph_model, regions): + """ + Extract and process permutation regions from the graph model. 
+ """ + for region in regions: + # Check if block size is compatible with the current shape + if not self._is_compatible_region(region): + continue + + # Directly add regions that already have sources identified + if (len(region.srcs) > 0): + # Skip the SDPA regions; potential head alignment issues + if 'value_sdpa' not in region.srcs_names: + self.regions.append(region) + continue + + # Skip if equalization criteria are not met + if not region.is_valid_activation_equalization: + continue + + # Create a new state for the online region + state = WalkRegionState(**self.full_state_kwargs) + + # Add all sinks from the region to the state + for sink_name, sink_wrapper in region.sinks.items(): + module = region.get_module_from_name(sink_name) + node = find_node_for_module(graph_model, module) + assert node is not None, f"Error: node {module} not found in graph" + eq_indexes = sink_wrapper.equalization_indexes + state.add_sinks(node.target, module, eq_indexes) + find_srcs(graph_model, node, state) + + # Skip region creation if unsupported operations were encountered + if _UNSUPPORTED_OP in state.sinks: + continue + + # Create a new region with updated sources but same sinks + new_region = Region.from_dicts( + srcs=state.srcs, + sinks=state.sinks, + name_to_module=state.name_to_module, + expand_region=region.expand_region) + self.regions.append(new_region) + + @staticmethod + def permute_region(region, list_of_act_val, block_size, permute_fn, device): + """ + Apply permutation to a region by calculating permutation indexes and updating + the source and sink weights accordingly. + """ + list_of_act_val_shapes = [act_val.shape for act_val in list_of_act_val] + if len(list_of_act_val_shapes) > 0: + shape_0 = list_of_act_val_shapes[0] + if any(shape_0 != shape for shape in list_of_act_val_shapes): + return + + list_of_act_val = torch.cat(list_of_act_val, dim=0).to(device) + new_indexes = permute_fn(list_of_act_val, block_size=block_size) + + for src in region.srcs.values(): + src.permute(new_indexes) + for sink in region.sinks.values(): + sink.permute(new_indexes) + + def apply(self, graph_model: GraphModule) -> GraphModule: + """ + Apply permutations to the graph model. + """ + for region in tqdm(self.regions, "Calculating permutations..."): + # Collect all activation values for this region + list_of_act_val = [] + for name in region.sinks_names: + act_vals = self.float_act_map.pop(name) + if act_vals is None or len(act_vals) == 0: + continue + list_of_act_val.extend(act_vals) + # Calculate permutation and apply to this region + self.permute_region( + region, + list_of_act_val=list_of_act_val, + block_size=self.block_size, + permute_fn=self.permute_fn, + device=self.float_act_dev[region.sinks_names[0]]) + return graph_model + + def cleanup(self): + for h in self.hooks: + h.remove() + + +class rotate_permute_mode: + """ + Context manager for applying rotation and permutation equalization. 
+ + Args: + model: The graph module to transform + rotation: Pre-initialized GraphRotationEqualization instance + permute_fn: Permutation method name + block_size: Block size for permutations + disable_for_fused_rotations: Whether to disable permutations for fused rotations + """ + + def __init__( + self, + model: GraphModule, + rotation: GraphRotationEqualization, + block_size: int, + permute_fn: str = 'massdiff', + disable_for_fused_rotations: bool = False, + extra_state_kwargs: Optional[Dict[str, Tuple[Type[nn.Module]]]] = None): + + assert rotation is not None and isinstance(rotation, GraphRotationEqualization), \ + "Error: expected GraphRotationEqualization instance" + assert rotation.delay_rewriters, "Error: expected rotation.delay_rewriters=True" + assert rotation.return_rewriters, "Error: expected rotation.return_rewriters=True" + assert isinstance(block_size, int) and block_size > 1, "Error: expected integer > 1" + + self.model = model + self.rotation = rotation + self.permute_fn = permute_fn + self.block_size = block_size + self.disable_for_fused_rotations = disable_for_fused_rotations + + self.permutation = GraphPermutationEqualization( + block_size=block_size, permute_fn=permute_fn, extra_state_kwargs=extra_state_kwargs) + self.rewriters = [] + + def _filter_regions(self, regions: List[Region]) -> List[Region]: + """ + Given rotation regions, filter out regions where permutations shouldn't be applied + """ + permute_regions = [] + for region in regions: + # Optionally disable permutations for fused rotations by skipping those regions + if self.disable_for_fused_rotations and (len(region.srcs) > 0): + continue + permute_regions.append(region) + return permute_regions + + def __enter__(self): + # Apply rotations and get rewriters + model, rewriters = self.rotation.apply(self.model) + self.model = model + self.rewriters = rewriters + + # Filter and setup permutation hooks based on rotation regions + regions = self.rotation.get_regions() + regions = self._filter_regions(regions) + self.model = self.permutation.setup(self.model, regions) + return self + + def __exit__(self, *args, **kwargs): + # Apply permutations and cleanup + self.model = self.permutation.apply(self.model) + self.permutation.cleanup() diff --git a/src/brevitas/graph/utils.py b/src/brevitas/graph/utils.py index a055e0f30..8503d6e66 100644 --- a/src/brevitas/graph/utils.py +++ b/src/brevitas/graph/utils.py @@ -5,6 +5,7 @@ from typing import Any from typing import Dict from typing import Iterable +from typing import Optional from typing import Tuple import torch @@ -20,6 +21,7 @@ 'replace_all_uses_except', 'signature_keys', 'is_subseq', + 'find_node_for_module', 'get_module_name_and_parent', 'set_module', 'get_module', @@ -92,6 +94,18 @@ def is_subseq(seq, subseq): return any(subseq == seq[i:len(subseq) + i] for i in range(len(seq) - len(subseq) + 1)) +def find_node_for_module(graph_model, target_module) -> Optional[Node]: + """ + Find the graph node corresponding to a module instance by matching its identity. 
+ """ + for node in graph_model.graph.nodes: + if node.op == 'call_module': + module = get_module(graph_model, node.target) + if id(module) == id(target_module): + return node + return None + + def get_module_name_and_parent(model, fully_qualified_module_name): supermodule = model prefix_list = fully_qualified_module_name.split('.') diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 9e55ec69a..f68fd344f 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -7,6 +7,7 @@ import torch from brevitas.common import ExportMixin +from brevitas.graph.base import INPUT_NAMES from brevitas.graph.hadamard import find_closest_hadamard_number from brevitas.graph.hadamard import get_hadK from brevitas.graph.hadamard import matmul_hadU @@ -20,8 +21,6 @@ except: fast_hadamard_transform = None -INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states'] - class EqualizedModule(torch.nn.Module, LayerProtocol, ExportMixin): diff --git a/src/brevitas_examples/llm/llm_args.py b/src/brevitas_examples/llm/llm_args.py index 9d4ddc972..e10d9ad49 100644 --- a/src/brevitas_examples/llm/llm_args.py +++ b/src/brevitas_examples/llm/llm_args.py @@ -7,8 +7,6 @@ from typing import Optional from warnings import warn -import torch - from brevitas_examples.common.parse_utils import create_entrypoint_args_parser from brevitas_examples.common.parse_utils import quant_format_validator @@ -392,6 +390,13 @@ def create_args_parser() -> ArgumentParser: default=0.5, type=float, help='If activation equalization is enabled, decide what alpha to use') + parser.add_argument( + '--permute-fn', + choices=['absmax', 'massdiff', 'zigzag', 'random'], + default=None, + help= + 'Permutation function to use. If None, no permutation is applied. Works with block rotations when both are enabled.' 
+ ) parser.add_argument( '--export-target', default=None, diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b7cc4bb6b..abef4df38 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -22,6 +22,7 @@ from brevitas.graph.equalize import fuse_parametrizations from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation +from brevitas.graph.permute import rotate_permute_mode from brevitas.graph.quantize import functional_quantization_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.utils import get_module @@ -114,12 +115,14 @@ def fused_rotation_no_fx(model, calibration_loader, args): for r in rewriters: r.apply(model) + fx_model = offload_model(fx_model) # Since we apply the rewriters to a different, non-fx model, we need only to compute them # And apply them in a second moment on the non-fx model delay_rewriters = True return_rewriters = True + extra_state_kwargs = {'scale_invariant_layers': rmsnorm_classes} eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, @@ -131,13 +134,34 @@ def fused_rotation_no_fx(model, calibration_loader, args): layers_to_expand=layers_to_expand, block_rotation_dim=args.block_rotation_dim, disable_block_rotation_for_fused=args.disable_block_rotation_for_fused, - extra_state_kwargs={'scale_invariant_layers': rmsnorm_classes}) - fx_model, rewriters = eq.apply(fx_model) - - model = offload_model(model) - rewriters = fix_rewriter(rewriters, model, 'weight') + extra_state_kwargs=extra_state_kwargs) + + if args.permute_fn is not None: + print("Applying permutations...") + with rotate_permute_mode(fx_model, + rotation=eq, + permute_fn=args.permute_fn, + block_size=args.block_rotation_dim, + disable_for_fused_rotations=args.disable_block_rotation_for_fused, + extra_state_kwargs=extra_state_kwargs) as rpm: + + # Get fx_model from the context manager + fx_model = rpm.model + # Run calibration on fx_model to collect activation statistics + with torch.no_grad(): + fx_model(**next(iter(calibration_loader))) + # Get rewriters from the context manager + rewriters = rpm.rewriters + else: + # Only rotation enabled + fx_model, rewriters = eq.apply(fx_model) - model = apply_rewriters(model, rewriters, delay_rewriters=False) + # fused_rotation_no_fx() may be called either if args.rotation == 'fused_no_fx' or args.permute_fn is not None, + # so if args.rotation == 'layerwise', we need to skip applying the rewriters here to do it later + if args.rotation != 'layerwise': + model = offload_model(model) + rewriters = fix_rewriter(rewriters, model, 'weight') + model = apply_rewriters(model, rewriters, delay_rewriters=False) def set_seed(seed): @@ -352,7 +376,15 @@ def quantize_llm(args, extra_args=None): extra_state_kwargs={'scale_invariant_layers': rmsnorm_classes}) model = eq.apply(model) remove_hooks(model) - elif args.rotation == 'layerwise': + + # Permutations are always fused. So, if we are applying them, then we go through + # the 'fused_no_fx' path to get the permutation-equivariant regions in the graph. + # If args.rotation == 'layerwise', then the rotations will not be applied in + # fused_rotation_no_fx(). Rotations will be added in the layerwise block below. 
+ if args.rotation == 'fused_no_fx' or args.permute_fn is not None: + fused_rotation_no_fx(model, calibration_loader, args) + + if args.rotation == 'layerwise': model = offload_model(model) eq = LayerwiseActivationRotation( layers_to_expand=layers_to_expand, @@ -360,8 +392,6 @@ def quantize_llm(args, extra_args=None): block_rotation_dim=args.block_rotation_dim) model = eq.apply(model) remove_hooks(model) - elif args.rotation == 'fused_no_fx': - fused_rotation_no_fx(model, calibration_loader, args) if args.weight_equalization: print("Apply weight equalization...") @@ -600,10 +630,12 @@ def quantize_llm(args, extra_args=None): model = offload_model(model) if args.load_checkpoint: + print(f"Loading checkpoint from {args.checkpoint_name}...") remove_hooks(model) with load_quant_model_mode(model): model.load_state_dict(torch.load(args.checkpoint_name, map_location='cpu')) model = offload_model(model) + print("Checkpoint loaded.") if args.gptq and not args.load_checkpoint: print("Applying GPTQ...") diff --git a/src/brevitas_examples/papers/mixquant/README.md b/src/brevitas_examples/papers/mixquant/README.md new file mode 100644 index 000000000..58bce2113 --- /dev/null +++ b/src/brevitas_examples/papers/mixquant/README.md @@ -0,0 +1,22 @@ +# MixQuant: Pushing the Limits of Block Rotations in Post-Training Quantization + +📄 [Paper](https://arxiv.org/pdf/2601.22347) +💻 [Code](https://github.com/Xilinx/brevitas/pull/1448) + + +``` +@article{sanjeet2026mixquant, + title={MixQuant: Pushing the Limits of Block Rotations in Post-Training Quantization}, + author={Sai Sanjeet and Ian Colbert and Pablo Monteagudo-Lago and Giuseppe Franco and Yaman Umuroglu and Nicholas J. Fraser}, + year={2026}, + eprint={2601.22347}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2601.22347}, +} +``` + +> [!IMPORTANT] +> These yaml files were tested with transformers==4.57.3 and lighteval==0.13.0 + +Please use https://github.com/i-colbert/brevitas/tree/mixquant/src/brevitas_examples/papers/mixquant to reproduce the experiments used for the paper. 
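A hypothetical driver for the config below, sketched against `create_args_parser` and `quantize_llm` from this PR; the defaults-then-override merging is an assumption for illustration, not the example's official entry point:

```python
# Hypothetical driver: merge the YAML below into the example's argparse defaults
# and hand the result to quantize_llm(). Assumes the parser's defaults are
# complete enough to start from; the official runner may differ.
import yaml

from brevitas_examples.llm.llm_args import create_args_parser
from brevitas_examples.llm.main import quantize_llm

args = create_args_parser().parse_args([])  # start from the example defaults

with open("llama3-mixquant-int4.yml") as f:
    for key, value in yaml.safe_load(f).items():
        setattr(args, key.replace("-", "_"), value)  # e.g. permute_fn='massdiff'

quantize_llm(args)
```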
diff --git a/src/brevitas_examples/papers/mixquant/llama3-mixquant-int4.yml b/src/brevitas_examples/papers/mixquant/llama3-mixquant-int4.yml new file mode 100644 index 000000000..5bd09310f --- /dev/null +++ b/src/brevitas_examples/papers/mixquant/llama3-mixquant-int4.yml @@ -0,0 +1,35 @@ +block_rotation_dim: 32 +disable_block_rotation_for_fused: true +permute_fn: massdiff +dataset_eval_split: test +dtype: bfloat16 +eval: true +few_shot_eval: lighteval +few_shot_override_batch_size: 128 +few_shot_tasks: +- arc:challenge|0 +- arc:easy|0 +- winogrande|0 +- hellaswag_lm_eval|0 +- piqa_lm_eval|0 +few_shot_zeroshot: true +gpxq_act_order: true +gpxq_block_name: model.layers +gpxq_use_quant_activations: true +gpxq_buffer_device: same +input_bit_width: 4 +input_quant_granularity: per_row +input_scale_type: dynamic +qronos: true +qronos_alpha: 1e-3 +replace_rmsnorm: true +rotation: fused_no_fx +rotation_sdpa_regions: true +rotation_mode: had +rotation_orphan_sink: true +seed: 42 +model: meta-llama/Llama-3.2-1B-Instruct +weight_bit_width: 4 +weight_param_method: mse +weight_quant_granularity: per_channel +weight_quant_type: sym diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 45020fcc4..d36365a0f 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -608,18 +608,12 @@ def forward(self, x): @pytest_cases.fixture -def block_residual_model(): - return functools.partial(BlockResidualModel, is_tied=False) +@pytest_cases.parametrize('is_tied', [True, False]) +def block_residual_model(is_tied): + return functools.partial(BlockResidualModel, is_tied=is_tied) -@pytest_cases.fixture -def block_residual_model_tied(): - return functools.partial(BlockResidualModel, is_tied=True) - - -list_of_rotation_fixtures = [ - "block_residual_model", - "block_residual_model_tied",] +list_of_rotation_fixtures = ["block_residual_model"] rotation_model = fixture_union( 'rotation_model', list_of_rotation_fixtures, ids=list_of_rotation_fixtures) diff --git a/tests/brevitas/graph/test_permute.py b/tests/brevitas/graph/test_permute.py new file mode 100644 index 000000000..afdd66370 --- /dev/null +++ b/tests/brevitas/graph/test_permute.py @@ -0,0 +1,169 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import pytest_cases +import torch + +from brevitas.graph.equalize import GraphRotationEqualization +from brevitas.graph.permute import GraphPermutationEqualization +from brevitas.graph.permute import rotate_permute_mode +from tests.marker import requires_pt_ge + +from .equalization_fixtures import * + + +def _has_tied_parameters(model: torch.nn.Module): + """Auxiliary method to check if the model has tied parameters""" + # get all model parameters and their names + all_named_parameters = { + name: param for name, param in model.named_parameters(remove_duplicate=False)} + + # get only unique named parameters + no_duplicate_named_parameters = { + name: param for name, param in model.named_parameters(remove_duplicate=True)} + + # the difference of the two sets gives us the tied parameters + tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys()) + + return len(tied_param_names) > 0 + + +def _setup_test_model(rotation_model, device='cpu'): + """ + Helper function to set up a test model.
+ + Returns: + tuple: (model, sample_inputs) where model is the FX-traced model on device + """ + # Instantiate model + model = rotation_model() + + # Skip tied parameters + if _has_tied_parameters(model): + pytest.skip("Skipping tests with tied parameters.") + + device = torch.device(device) + model.to(device) + + # Sample input + sample_inputs = torch.rand(size=(5, IN_FEATURES)).to(device) + + # Convert to FX graph + with torch.no_grad(): + fx_model, _ = torch._dynamo.export(model)(sample_inputs) + + return fx_model, sample_inputs + + +@requires_pt_ge('2.3.1') +@pytest_cases.parametrize('permute_fn', ['massdiff', 'zigzag', 'absmax', 'random']) +@pytest_cases.parametrize('block_size', [8, IN_FEATURES]) +@pytest_cases.parametrize('expansion_step', [0, 3]) +@pytest_cases.parametrize('disable_for_fused_rotations', [True, False]) +@pytest_cases.parametrize('orphan_sink', [True, False]) +@pytest_cases.parametrize('device', ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']) +def test_rotate_permute_mode( + rotation_model, + permute_fn, + block_size, + expansion_step, + disable_for_fused_rotations, + orphan_sink, + device): + """Test rotate_permute_mode context manager with various configurations.""" + # Setup model + model, sample_inputs = _setup_test_model(rotation_model, device) + model.eval() + with torch.no_grad(): + expected_output = model(sample_inputs) + + # Create rotation instance + rotation = GraphRotationEqualization( + expansion_step=expansion_step, + layers_to_expand=[], + block_rotation_dim=block_size, + orphan_sink=orphan_sink, + disable_block_rotation_for_fused=disable_for_fused_rotations, + return_rewriters=True, + delay_rewriters=True) + + # Apply rotation and permutation through context manager + with rotate_permute_mode(model, + rotation=rotation, + permute_fn=permute_fn, + block_size=block_size, + disable_for_fused_rotations=disable_for_fused_rotations) as rpm: + permute_regions = rpm.permutation.regions + permute_float_act_map = rpm.permutation.float_act_map + with torch.no_grad(): + rpm.model(sample_inputs) + # Verify activation maps were populated if regions exist + if len(permute_regions) > 0: + assert len(permute_float_act_map) > 0, \ + "Activation maps should be populated after forward pass" + if (orphan_sink or not disable_for_fused_rotations) and block_size < IN_FEATURES: + assert len(permute_regions) > 0 + + # Verify output invariance + with torch.no_grad(): + output = model(sample_inputs) + assert torch.allclose(expected_output, output, atol=ATOL), \ + "Output mismatch with combined features" + + +@requires_pt_ge('2.3.1') +@pytest_cases.parametrize('block_size', [4, 8, 16, 24, 32]) +@pytest_cases.parametrize('device', ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']) +def test_permute_block_size_compatibility(rotation_model, block_size, device): + """ + Test block size compatibility with different model dimensions and region filtering. + + For IN_FEATURES=24, compatible block sizes are: 2, 3, 4, 6, 8, 12 + Block size of 24 is not compatible. + Block sizes like 5, 7, 16, 32 should be incompatible and regions should be filtered. + Verify this behavior is correct. 
+ """ + # Setup model + model, sample_inputs = _setup_test_model(rotation_model, device) + model.eval() + with torch.no_grad(): + expected_output = model(sample_inputs) + + # Apply rotation to get regions + rotation = GraphRotationEqualization( + expansion_step=0, + layers_to_expand=[], + block_rotation_dim=block_size, + disable_block_rotation_for_fused=False, + return_rewriters=True, + delay_rewriters=True) + + model, rewriters = rotation.apply(model) + regions = rotation.get_regions() + + # Setup permutation - this should handle incompatible block sizes gracefully + permutation = GraphPermutationEqualization(block_size=block_size, permute_fn='massdiff') + + model = permutation.setup(model, regions) + + if block_size in [16, 24, 32]: + assert len(permutation.regions) == 0 + + # Verify that SDPA regions are filtered (regions with 'value_sdpa' in source names) + for region in permutation.regions: + assert 'value_sdpa' not in region.srcs_names, \ + "SDPA regions should be filtered out" + + # Run model to collect statistics and apply permutations + with torch.no_grad(): + model(sample_inputs) + + model = permutation.apply(model) + permutation.cleanup() + + # Verify output invariance + with torch.no_grad(): + output = model(sample_inputs) + assert torch.allclose(expected_output, output, atol=ATOL), \ + "Output changed after permutation - invariance violated"