Commit c10ef87

Fix (equalize): Fix LayerwiseActivationRotation (#1413)
1 parent 9421f56 commit c10ef87

3 files changed: +26 -12 lines changed

src/brevitas/graph/equalize.py

Lines changed: 15 additions & 12 deletions
@@ -1774,6 +1774,17 @@ def find_module_by_name(self, model: nn.Module, regions: List[Region], prefix: s
                 full_name = prefix + '.' + name if prefix != '' else name
                 self.find_module_by_name(module, regions, full_name)
 
+    def transform_model(
+            self, model: nn.Module, rewriters: List[Transform], delay_rewriters: bool) -> nn.Module:
+        # In some circumstances, it might be useful to apply model transformations at a later moment
+        # The user should not be responsible for this in any case
+        if delay_rewriters:
+            return model
+        if is_model_offloaded_accelerate(model):
+            return apply_rewriters_accelerate(model, rewriters)
+        else:
+            return apply_rewriters(model, rewriters)
+
 
 class GraphRotationEqualization(RotationEqualization):
 
@@ -2005,16 +2016,6 @@ def apply(self,
         else:
             return graph_model
 
-    def transform_model(self, model, rewriters, delay_rewriters):
-        # In some circumstances, it might be useful to apply model transformations at a later moment
-        # The user should not be resposible for this in any case
-        if delay_rewriters:
-            return model
-        if is_model_offloaded_accelerate(model):
-            return apply_rewriters_accelerate(model, rewriters)
-        else:
-            return apply_rewriters(model, rewriters)
-
 
 @torch.no_grad()
 def apply_rewriters(
@@ -2114,15 +2115,17 @@ def __init__(
         self.supported_sinks = (nn.Linear)
 
     def apply(self, model: nn.Module) -> nn.Module:
+        regions: List[Region] = []
+        rewriters: List[Transform] = []
 
         blacklist_orphan_layers = self.blacklist_layers + self.layers_to_expand
-        regions: List[Region] = []
         self.find_module(model, regions, blacklist_layers=blacklist_orphan_layers)
         expanded_regions = []
         self.find_module_by_name(model, expanded_regions)
 
         if len(expanded_regions) > 0:
             regions.extend(expanded_regions)
         if len(regions) > 0:
-            _compute_rotations(model, regions, expansion_step=self.expansion_step)
+            rewriters.extend(_compute_rotations(model, regions, expansion_step=self.expansion_step))
+            model = self.transform_model(model, rewriters, delay_rewriters=False)
         return model
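Taken together, these hunks hoist transform_model out of GraphRotationEqualization into the shared rotation base class, so LayerwiseActivationRotation.apply can now collect the rewriters returned by _compute_rotations and funnel them through the same offload-aware application path instead of dropping them. A minimal, self-contained sketch of the resulting pattern, with stand-in names rather than Brevitas internals:

    # Sketch only: BaseRotation and ToyLayerwiseRotation are illustrative
    # stand-ins; the real Brevitas code additionally dispatches to
    # accelerate-aware helpers when the model is offloaded.
    from typing import Callable, List

    import torch.nn as nn


    class BaseRotation:

        def transform_model(self, model: nn.Module,
                            rewriters: List[Callable[[nn.Module], nn.Module]],
                            delay_rewriters: bool) -> nn.Module:
            # Optionally postpone the rewrites (useful when batching several
            # passes); otherwise apply every collected rewriter in order.
            if delay_rewriters:
                return model
            for rewriter in rewriters:
                model = rewriter(model)
            return model


    class ToyLayerwiseRotation(BaseRotation):

        def apply(self, model: nn.Module) -> nn.Module:
            # Collect rewriters first, then hand them to the shared path,
            # mirroring the fixed LayerwiseActivationRotation.apply above.
            rewriters: List[Callable[[nn.Module], nn.Module]] = [lambda m: m]
            return self.transform_model(model, rewriters, delay_rewriters=False)


    model = ToyLayerwiseRotation().apply(nn.Linear(4, 4))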

src/brevitas_examples/llm/main.py

Lines changed: 2 additions & 0 deletions
@@ -335,9 +335,11 @@ def quantize_llm(args, extra_args=None):
         model = eq.apply(model)
         remove_hooks(model)
     elif args.rotation == 'layerwise':
+        model = offload_model(model)
         eq = LayerwiseActivationRotation(
             layers_to_expand=layers_to_expand, expansion_step=args.expansion_step)
         model = eq.apply(model)
+        remove_hooks(model)
     elif args.rotation == 'fused_no_fx':
         fused_rotation_no_fx(model, calibration_loader, args)
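The main.py change brackets the layerwise branch with the same offload/remove_hooks pair used around the other rotation branches: the model is offloaded before the rotation is applied and the accelerate hooks are stripped afterwards. A hedged sketch of that call pattern (the import paths below are assumptions, since the diff does not show the file's imports):

    # Import paths are assumed; main.py already calls offload_model and
    # remove_hooks, but this hunk does not show where they come from.
    from brevitas.graph.equalize import LayerwiseActivationRotation
    from brevitas_examples.common.accelerate_utils.accelerate import offload_model, remove_hooks


    def rotate_layerwise(model, layers_to_expand, expansion_step):
        # Offload so large checkpoints fit, apply the layerwise rotation on the
        # offloaded model, then drop the hooks, as the fixed branch does.
        model = offload_model(model)
        eq = LayerwiseActivationRotation(
            layers_to_expand=layers_to_expand, expansion_step=expansion_step)
        model = eq.apply(model)
        remove_hooks(model)
        return model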

tests/brevitas_examples/test_llm_cases.py

Lines changed: 9 additions & 0 deletions
@@ -332,6 +332,14 @@ class LLMQuantLayerTypeCases:
                     "<class 'brevitas.nn.equalized_layer.EqualizedModule'>",
                 "model.layers.0.self_attn.q_proj.layer":
                     "<class 'brevitas.nn.quant_linear.QuantLinear'>",},},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "rotation": "layerwise",
+            "exp_layer_types": {
+                "model.layers.0.self_attn.q_proj":
+                    "<class 'brevitas.nn.equalized_layer.RotatedModule'>",
+                "model.layers.0.self_attn.q_proj.layer":
+                    "<class 'brevitas.nn.quant_linear.QuantLinear'>",},},
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "quantize_last_layer": True,
@@ -360,6 +368,7 @@ class LLMQuantLayerTypeCases:
         "mistral-fp8_fnuz",
         "llama-mxfp8",
         "llama-int8-act_equalization=layerwise",
+        "llama-int8-rotation=layerwise",
         "mistral-int8-quant-last-layer",
         "llama-int8-svd_quant",
         "opt-quant-sdpa",],)
