File tree Expand file tree Collapse file tree 1 file changed +20
-20
lines changed
examples/models/llama/source_transformation Expand file tree Collapse file tree 1 file changed +20
-20
lines changed Original file line number Diff line number Diff line change 1- # Copyright (c) Meta Platforms, Inc. and affiliates.
2- # All rights reserved.
3- #
4- # This source code is licensed under the BSD-style license found in the
5- # LICENSE file in the root directory of this source tree.
1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
66
7- import torch
8- from executorch .examples .models .llama .llama_transformer import RMSNorm
7+ import torch
8+ from executorch .examples .models .llama .llama_transformer import RMSNorm
99
1010
11- def replace_rms_norm_with_native_rms_norm (module : torch .nn .Module ):
12- for name , child in module .named_children ():
13- if isinstance (child , RMSNorm ):
14- rms_norm = torch .nn .RMSNorm (child .dim , eps = child .eps )
15- rms_norm .weight = child .weight
16- setattr (
17- module ,
18- name ,
19- rms_norm ,
20- )
21- else :
22- replace_rms_norm_with_native_rms_norm (child )
23- return module
11+ def replace_rms_norm_with_native_rms_norm (module : torch .nn .Module ):
12+ for name , child in module .named_children ():
13+ if isinstance (child , RMSNorm ):
14+ rms_norm = torch .nn .RMSNorm (child .dim , eps = child .eps )
15+ rms_norm .weight = child .weight
16+ setattr (
17+ module ,
18+ name ,
19+ rms_norm ,
20+ )
21+ else :
22+ replace_rms_norm_with_native_rms_norm (child )
23+ return module
You can’t perform that action at this time.
0 commit comments