Commit 26798a4

Feat (brevitas_examples/llm): better RMSNorm replacement (#1436)
1 parent a9bb59e commit 26798a4

File tree

4 files changed: +147 -78 lines changed

src/brevitas/graph/base.py

Lines changed: 3 additions & 3 deletions
@@ -435,18 +435,18 @@ class ModuleToModuleByClass(ModuleToModule):
     def __init__(self, old_module_class, new_module_class, **kwargs):
         super().__init__(new_module_class, **kwargs)
         self.old_module_class = old_module_class
+        self.old_new_module_dict = {}
 
     def apply(self, model: GraphModule) -> GraphModule:
-        old_new_module_dict = {}
         for old_module in model.modules():
             # check for equality, not inheritance
             if type(old_module) == self.old_module_class:
                 # init the new module based on the old one
                 new_module = self.init_new_module(old_module)
                 # register modules pair to be replaced
-                old_new_module_dict[old_module] = new_module
+                self.old_new_module_dict[old_module] = new_module
         # replace all pairs registered
-        for old_module, new_module in old_new_module_dict.items():
+        for old_module, new_module in self.old_new_module_dict.items():
             replace_module(model, old_module, new_module)
         return model
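The replacement mapping is now kept on the rewriter instance instead of a local variable, so it survives apply() and can be read back afterwards (the new rmsnorm_patch context manager below relies on this to undo its swap). A minimal usage sketch, assuming torch >= 2.4 and a made-up MyRMSNorm class standing in for a model-specific RMSNorm:

import torch
from torch import nn

from brevitas.graph.base import ModuleToModuleByClass


class MyRMSNorm(nn.Module):
    # Hypothetical stand-in for a model-specific RMSNorm implementation.

    def __init__(self, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))


model = nn.Sequential(nn.Linear(8, 8), MyRMSNorm(8))

rewriter = ModuleToModuleByClass(
    MyRMSNorm, torch.nn.RMSNorm, normalized_shape=lambda module: module.weight.shape[0])
model = rewriter.apply(model)

# The old -> new pairs now survive apply(), so they can be inspected or replayed later.
for old_module, new_module in rewriter.old_new_module_dict.items():
    print(type(old_module).__name__, '->', type(new_module).__name__)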

src/brevitas/graph/equalize.py

Lines changed: 84 additions & 44 deletions
@@ -293,7 +293,7 @@ class WalkRegionState:
 
     supported_srcs: set = _supported_layers
     supported_sinks: set = _supported_layers
-    scale_invariant_function: set = _scale_invariant_op
+    scale_invariant_functions: set = _scale_invariant_op
     scale_invariant_layers: set = _scale_invariant_layers
     residual_fns: set = _residual_fns
     residual_methods: set = _residual_methods
@@ -1027,7 +1027,7 @@ def find_srcs_channel_dim(state, model, inp_node):
         return total_channels
     elif _is_scale_invariant_module(model, inp_node,
                                     state.scale_invariant_layers) or _is_scale_invariant_function(
-                                        inp_node, state.scale_invariant_function):
+                                        inp_node, state.scale_invariant_functions):
         return find_srcs_channel_dim(state, model, inp_node.all_input_nodes[0])
     else:
         return _UNSUPPORTED_OP
@@ -1078,7 +1078,7 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
                 0]
     elif _is_scale_invariant_module(
             graph_model, node, state.scale_invariant_layers) or _is_scale_invariant_function(
-                node, state.scale_invariant_function):
+                node, state.scale_invariant_functions):
         find_sinks(graph_model, node, state)
         find_srcs(graph_model, node, state)
     elif _is_add(node, state.residual_fns, state.residual_methods):
@@ -1126,7 +1126,7 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
 
     elif _is_scale_invariant_module(
             graph_model, node, state.scale_invariant_layers) or _is_scale_invariant_function(
-                node, state.scale_invariant_function):
+                node, state.scale_invariant_functions):
         find_sinks(graph_model, node, state)
     elif _is_add(node, state.residual_fns, state.residual_methods):
         state.update_offset = False
@@ -1785,10 +1785,43 @@ def _merge_ln(layer_norm, next_module, scale_bias_by_weight):
         _replace_bias(next_module, new_bias)
 
 
+class RegionWalkMixin:
+
+    def __init__(
+            self,
+            supported_srcs: Tuple[Type[nn.Module]] = _supported_layers,
+            supported_sinks: Tuple[Type[nn.Module]] = _supported_layers,
+            scale_invariant_layers: Tuple[Type[nn.Module]] = _scale_invariant_layers,
+            scale_invariant_functions: Tuple[Callable] = _scale_invariant_op,
+            residual_fns: Tuple[Callable] = _residual_fns,
+            residual_methods: Tuple[str] = _residual_methods,
+            extra_state_kwargs: Optional[Dict[str, Tuple[Type[nn.Module]]]] = None):
+        self.supported_srcs = supported_srcs
+        self.supported_sinks = supported_sinks
+        self.scale_invariant_layers = scale_invariant_layers
+        self.scale_invariant_functions = scale_invariant_functions
+        self.residual_fns = residual_fns
+        self.residual_methods = residual_methods
+
+        if extra_state_kwargs is not None:
+            for attr_name, value in extra_state_kwargs.items():
+                combined_value = value + getattr(self, attr_name)
+                setattr(self, attr_name, combined_value)
+
+    @property
+    def full_state_kwargs(self) -> Dict[str, Tuple[Type[nn.Module]]]:
+        return {
+            'supported_srcs': self.supported_srcs,
+            'supported_sinks': self.supported_sinks,
+            'scale_invariant_layers': self.scale_invariant_layers,
+            'scale_invariant_functions': self.scale_invariant_functions,
+            'residual_fns': self.residual_fns,
+            'residual_methods': self.residual_methods}
+
+
 class RotationEqualization(GraphTransform):
 
     def __init__(self, blacklist_layers, layers_to_expand) -> None:
-        super(RotationEqualization, self).__init__()
         if blacklist_layers is not None:
             self.blacklist_layers = blacklist_layers
         else:
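The new mixin centralizes the region-walk configuration that was previously re-declared on each transform; entries passed through extra_state_kwargs are concatenated in front of the defaults rather than replacing them. A small illustrative sketch of that merge (not part of the commit):

from torch import nn

from brevitas.graph.equalize import RegionWalkMixin

mixin = RegionWalkMixin(
    supported_sinks=(nn.Linear,),
    extra_state_kwargs={'supported_sinks': (nn.Conv2d,)})

# Extra entries come first, followed by the explicitly passed (or default) ones.
assert mixin.supported_sinks == (nn.Conv2d, nn.Linear)

# full_state_kwargs packages the whole configuration for
# _extract_regions(graph_model, state_impl_kwargs=...).
print(mixin.full_state_kwargs['supported_sinks'])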
@@ -1797,19 +1830,19 @@ def __init__(self, blacklist_layers, layers_to_expand) -> None:
             self.layers_to_expand = layers_to_expand
         else:
             self.layers_to_expand = []
-        self.supported_sinks = ()
 
     def find_module(
             self,
             model: nn.Module,
             regions: List[Region],
+            supported_sinks: Tuple[nn.Module],
             prefix: str = '',
             blacklist_layers: Optional[List[str]] = None):
         """
         Iterate through the model looking at immediate children of every module to look for supported modules.
         This allows us to stop the search when we meet a top-level module that is supported.
         """
-        if isinstance(model, self.supported_sinks):
+        if isinstance(model, supported_sinks):
             if prefix in blacklist_layers:
                 return
             weight = get_weight_sink(model)
@@ -1820,7 +1853,7 @@ def find_module(
         else:
             for name, module in model.named_children():
                 full_name = prefix + '.' + name if prefix != '' else name
-                self.find_module(module, regions, full_name, blacklist_layers)
+                self.find_module(module, regions, supported_sinks, full_name, blacklist_layers)
 
     def find_module_by_name(self, model: nn.Module, regions: List[Region], prefix: str = ''):
         """
@@ -1852,7 +1885,7 @@ def transform_model(
         return apply_rewriters(model, rewriters)
 
 
-class GraphRotationEqualization(RotationEqualization):
+class GraphRotationEqualization(RotationEqualization, RegionWalkMixin):
 
     def __init__(
             self,
@@ -1866,16 +1899,20 @@ def __init__(
             layers_to_expand: Optional[List[str]] = None,
             expansion_step: int = None,
             delay_rewriters: bool = False,
-            return_rewriters: bool = False) -> None:
-        super(GraphRotationEqualization, self).__init__(blacklist_layers, layers_to_expand)
+            return_rewriters: bool = False,
+            extra_state_kwargs: Optional[Dict[str, Tuple]] = None) -> None:
+        RotationEqualization.__init__(self, blacklist_layers, layers_to_expand)
 
-        self.supported_srcs = (nn.Linear, nn.Embedding)
-        self.supported_sinks = (nn.Linear)
         common_scale_invariant = list(_scale_invariant_layers)
         common_scale_invariant.remove(torch.nn.ReLU)
         common_scale_invariant.remove(torch.nn.LeakyReLU)
-        self.scale_invariant_layers = tuple(common_scale_invariant) + (RMSNorm,)
-        self.scale_invariant_function = ()
+        base_state_kwargs = {
+            'supported_srcs': (nn.Linear, nn.Embedding),
+            'supported_sinks': (nn.Linear,),
+            'scale_invariant_layers': tuple(common_scale_invariant) + (RMSNorm,),
+            'scale_invariant_functions': ()}
+        RegionWalkMixin.__init__(self, **base_state_kwargs, extra_state_kwargs=extra_state_kwargs)
+
         self.orphan_sink = orphan_sink
         self.rotate_matmul = rotate_matmul
         self.full_rotation_method = full_rotation_method
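With RegionWalkMixin in the MRO, callers can widen the rotation region walk without subclassing, e.g. to register a model's own RMSNorm implementation as scale-invariant. A hedged sketch, assuming the remaining constructor arguments keep their defaults:

from brevitas.graph.equalize import GraphRotationEqualization

rmsnorm_classes = ()  # placeholder: the model's own RMSNorm classes, collected by the caller

eq = GraphRotationEqualization(
    extra_state_kwargs={'scale_invariant_layers': rmsnorm_classes})
# eq.apply(graph_model) then extracts regions with the extended layer set
# via self.full_state_kwargs instead of the old hard-coded attributes.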
@@ -1992,13 +2029,7 @@ def find_sink(node):
     def apply(self,
               graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
         rewriters = []
-        regions = _extract_regions(
-            graph_model,
-            state_impl_kwargs={
-                'supported_srcs': self.supported_srcs,
-                'supported_sinks': self.supported_sinks,
-                'scale_invariant_layers': self.scale_invariant_layers,
-                'scale_invariant_function': self.scale_invariant_function})
+        regions = _extract_regions(graph_model, state_impl_kwargs=self.full_state_kwargs)
 
         expanded_regions = []
         self.find_module_by_name(graph_model, expanded_regions)
@@ -2007,7 +2038,11 @@ def apply(self,
 
         if self.orphan_sink:
             blacklist_orphan_layers = self.blacklist_layers + self.layers_to_expand
-            self.find_module(graph_model, orphan_regions, blacklist_layers=blacklist_orphan_layers)
+            self.find_module(
+                graph_model,
+                orphan_regions,
+                self.full_state_kwargs['supported_sinks'],
+                blacklist_layers=blacklist_orphan_layers)
 
         if len(expanded_regions) > 0:
             parameter_number_pre = 0
@@ -2095,20 +2130,23 @@ def apply_rewriters(
         return model
 
 
-class LayerNormToRMS(GraphTransform):
+class LayerNormToRMS(GraphTransform, RegionWalkMixin):
+
+    def __init__(
+            self,
+            return_rewriters: bool = False,
+            extra_state_kwargs: Optional[Dict[str, Tuple]] = None) -> None:
+        GraphTransform.__init__(self)
+
+        base_state_kwargs = {
+            'supported_srcs': (nn.Linear, nn.Embedding), 'supported_sinks': (nn.LayerNorm,)}
+        RegionWalkMixin.__init__(self, **base_state_kwargs, extra_state_kwargs=extra_state_kwargs)
 
-    def __init__(self, return_rewriters=False) -> None:
-        super(LayerNormToRMS, self).__init__()
-        self.supported_srcs = (nn.Linear, nn.Embedding)
-        self.supported_sinks = (nn.LayerNorm)
         self.return_rewriters = return_rewriters
         assert RMSNorm is not object, 'Update your Pytorch version to 2.4+'
 
     def apply(self, graph_model: GraphModule) -> GraphModule:
-        regions = _extract_regions(
-            graph_model,
-            state_impl_kwargs={
-                'supported_srcs': self.supported_srcs, 'supported_sinks': self.supported_sinks})
+        regions = _extract_regions(graph_model, state_impl_kwargs=self.full_state_kwargs)
 
         rewriters = []
         if len(regions) > 0:
@@ -2141,18 +2179,17 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
         return graph_model
 
 
-class MergeLnAffine(GraphTransform):
+class MergeLnAffine(GraphTransform, RegionWalkMixin):
 
-    def __init__(self) -> None:
-        super(MergeLnAffine, self).__init__()
+    def __init__(self, extra_state_kwargs: Optional[Dict[str, Tuple]] = None) -> None:
+        GraphTransform.__init__(self)
         self.supported_srcs = (RMSNorm, nn.LayerNorm)
-        self.supported_sinks = (nn.Linear)
+        base_state_kwargs = {
+            'supported_srcs': (RMSNorm, nn.LayerNorm), 'supported_sinks': (nn.Linear,)}
+        RegionWalkMixin.__init__(self, **base_state_kwargs, extra_state_kwargs=extra_state_kwargs)
 
     def apply(self, graph_model: GraphModule) -> GraphModule:
-        regions = _extract_regions(
-            graph_model,
-            state_impl_kwargs={
-                'supported_srcs': self.supported_srcs, 'supported_sinks': self.supported_sinks})
+        regions = _extract_regions(graph_model, state_impl_kwargs=self.full_state_kwargs)
 
         if len(regions) > 0:
             scaled_biases = set()
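LayerNormToRMS and MergeLnAffine now delegate their region specification to the mixin, so the hand-rolled state_impl_kwargs dicts collapse into self.full_state_kwargs and both constructors accept the same extra_state_kwargs hook. A minimal sketch of the new surface, assuming torch >= 2.4 (MyRMSNorm is a placeholder for a model-specific class):

from torch import nn

from brevitas.graph.equalize import LayerNormToRMS
from brevitas.graph.equalize import MergeLnAffine


class MyRMSNorm(nn.Module):
    pass  # placeholder for a model-specific RMSNorm


# Extra classes are merged into the default scale-invariant set, mirroring how
# apply_layernorm_affine_merge wires this up in the LLM example below.
merge = MergeLnAffine(extra_state_kwargs={'scale_invariant_layers': (MyRMSNorm,)})
to_rms = LayerNormToRMS(return_rewriters=True)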
@@ -2180,18 +2217,21 @@ def __init__(
             blacklist_layer: Optional[List] = None,
             layers_to_expand: Optional[List] = None,
             expansion_step: int = 0,
-            block_rotation_dim: Optional[int] = None):
-        super().__init__(blacklist_layer, layers_to_expand)
+            block_rotation_dim: Optional[int] = None,
+            extra_state_kwargs: Optional[Dict[str, Tuple]] = None):
+
+        RotationEqualization.__init__(self, blacklist_layer, layers_to_expand)
         self.expansion_step = expansion_step
-        self.supported_sinks = (nn.Linear)
         self.block_rotation_dim = block_rotation_dim
+        self.supported_sinks = (nn.Linear,)
 
     def apply(self, model: nn.Module) -> nn.Module:
         regions: List[Region] = []
         rewriters: List[Transform] = []
 
         blacklist_orphan_layers = self.blacklist_layers + self.layers_to_expand
-        self.find_module(model, regions, blacklist_layers=blacklist_orphan_layers)
+        self.find_module(
+            model, regions, self.supported_sinks, blacklist_layers=blacklist_orphan_layers)
         expanded_regions = []
         self.find_module_by_name(model, expanded_regions)
 
src/brevitas_examples/llm/llm_quant/ln_affine_merge.py

Lines changed: 48 additions & 21 deletions
@@ -3,36 +3,63 @@
 SPDX-License-Identifier: MIT
 """
 
+from inspect import signature
+
 from packaging import version
 import torch
 from torch import nn
 
 from brevitas import torch_version
-from brevitas.graph.base import ModuleToModuleByClass
+from brevitas.graph import ModuleInstanceToModuleInstance
+from brevitas.graph import ModuleToModuleByClass
 from brevitas.graph.equalize import _is_scale_invariant_module
 from brevitas.graph.equalize import LayerNormToRMS
 from brevitas.graph.equalize import MergeLnAffine
 from brevitas.graph.utils import get_module
 
 
-def replace_rmsnorm_with_torch(model, config):
-    assert torch_version >= version.parse('2.4'), "torch.nn.RMSNorm requires torch 2.4 or greater"
-    set_of_layers = set(type(x) for x in model.modules() if 'RMS' in type(x).__name__)
-    dtype = next(model.parameters()).dtype
-    device = next(model.parameters()).device
-    rewriters = [
-        ModuleToModuleByClass(
-            rms_cls,
-            torch.nn.RMSNorm,
-            normalized_shape=lambda module: module.weight.shape[0],
-            eps=config.rms_norm_eps,
-            dtype=dtype,
-            device=device) for rms_cls in set_of_layers]
-    dtype = next(iter(model.parameters())).dtype
-    for r in rewriters:
-        model = r.apply(model)
-    model = model.to(dtype)
-    return model
+class rmsnorm_patch:
+
+    def __init__(self, model, config, enabled=True):
+        self.model = model
+        self.config = config
+        if enabled:
+            self.rmsnorm_classes = tuple(
+                set(type(x) for x in model.modules() if 'RMS' in type(x).__name__))
+        else:
+            self.rmsnorm_classes = tuple()
+        self.mapping = dict()
+
+    def __enter__(self):
+        assert torch_version >= version.parse('2.4'), "torch.nn.RMSNorm requires torch 2.4 or greater"
+
+        dtype = next(self.model.parameters()).dtype
+        device = next(self.model.parameters()).device
+
+        rewriters = [
+            ModuleToModuleByClass(
+                rms_cls,
+                torch.nn.RMSNorm,
+                normalized_shape=lambda module: module.weight.shape[0],
+                eps=self.config.rms_norm_eps,
+                dtype=dtype,
+                device=device) for rms_cls in self.rmsnorm_classes]
+
+        for r in rewriters:
+            self.model = r.apply(self.model)
+            self.mapping.update(r.old_new_module_dict)
+
+        self.model = self.model.to(dtype)
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        dtype = next(self.model.parameters()).dtype
+
+        for old_module, new_module in self.mapping.items():
+            rewriter = ModuleInstanceToModuleInstance(old_module, new_module)
+            self.model = rewriter.apply(self.model)
+
+        self.model = self.model.to(dtype)
 
 
 def replace_bias(next_module, new_bias):
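The one-shot replace_rmsnorm_with_torch helper becomes a context manager: on entry every custom RMSNorm class is rewritten to torch.nn.RMSNorm, and the old-to-new module pairs recorded by ModuleToModuleByClass are kept so they can be re-applied via ModuleInstanceToModuleInstance when the context exits. A hedged usage sketch, where model and config stand for a Hugging Face-style model and a config exposing rms_norm_eps:

from brevitas_examples.llm.llm_quant.ln_affine_merge import rmsnorm_patch

# model / config are placeholders, e.g. from transformers.AutoModelForCausalLM.
with rmsnorm_patch(model, config) as patch:
    # Inside the block the model's RMSNorm modules have been replaced with
    # torch.nn.RMSNorm; the swapped-out classes remain available for later passes.
    print(patch.rmsnorm_classes)
    # ... run graph transforms / quantization on patch.model here ...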
@@ -106,8 +133,8 @@ def merge_layernorm_affine_params(graph_model):
 
 
 @torch.no_grad()
-def apply_layernorm_affine_merge(graph_model):
-    eq = MergeLnAffine()
+def apply_layernorm_affine_merge(graph_model, rmsnorm_classes):
+    eq = MergeLnAffine(extra_state_kwargs={'scale_invariant_layers': rmsnorm_classes})
     graph_model = eq.apply(graph_model)
     return graph_model
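apply_layernorm_affine_merge now requires the caller to supply the RMSNorm classes that should be treated as scale-invariant; passing an empty tuple keeps the previous default behaviour. The tuple is typically the one collected by rmsnorm_patch, as in this hedged sketch of a call site (not code from this commit; model, config and graph_model are placeholders):

from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge
from brevitas_examples.llm.llm_quant.ln_affine_merge import rmsnorm_patch

patch = rmsnorm_patch(model, config)  # collects the model's RMSNorm classes
graph_model = apply_layernorm_affine_merge(graph_model, patch.rmsnorm_classes)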
