File tree Expand file tree Collapse file tree 2 files changed +24
-0
lines changed
Expand file tree Collapse file tree 2 files changed +24
-0
lines changed Original file line number Diff line number Diff line change @@ -113,6 +113,7 @@ runtime.python_library(
113113 "source_transformation/prune_vocab.py",
114114 "source_transformation/quantize.py",
115115 "source_transformation/custom_kv_cache.py",
116+ "source_transformation/rms_norm.py",
116117 "source_transformation/rope.py",
117118 "source_transformation/sdpa.py",
118119 "source_transformation/spin_quant.py",
Original file line number Diff line number Diff line change 1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
6+
7+ import torch
8+ from executorch .examples .models .llama .llama_transformer import RMSNorm
9+
10+
11+ def replace_rms_norm_with_native_rms_norm (module : torch .nn .Module ):
12+ for name , child in module .named_children ():
13+ if isinstance (child , RMSNorm ):
14+ rms_norm = torch .nn .RMSNorm (child .dim , eps = child .eps )
15+ rms_norm .weight = child .weight
16+ setattr (
17+ module ,
18+ name ,
19+ rms_norm ,
20+ )
21+ else :
22+ replace_rms_norm_with_native_rms_norm (child )
23+ return module
You can’t perform that action at this time.
0 commit comments