Skip to content

Commit 30d036f

Browse files
committed
linting
1 parent a4ffcfd commit 30d036f

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed
Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
1-
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# All rights reserved.
3-
#
4-
# This source code is licensed under the BSD-style license found in the
5-
# LICENSE file in the root directory of this source tree.
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
66

7-
import torch
8-
from executorch.examples.models.llama.llama_transformer import RMSNorm
7+
import torch
8+
from executorch.examples.models.llama.llama_transformer import RMSNorm
99

1010

11-
def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module):
12-
for name, child in module.named_children():
13-
if isinstance(child, RMSNorm):
14-
rms_norm = torch.nn.RMSNorm(child.dim, eps=child.eps)
15-
rms_norm.weight = child.weight
16-
setattr(
17-
module,
18-
name,
19-
rms_norm,
20-
)
21-
else:
22-
replace_rms_norm_with_native_rms_norm(child)
23-
return module
11+
def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module):
12+
for name, child in module.named_children():
13+
if isinstance(child, RMSNorm):
14+
rms_norm = torch.nn.RMSNorm(child.dim, eps=child.eps)
15+
rms_norm.weight = child.weight
16+
setattr(
17+
module,
18+
name,
19+
rms_norm,
20+
)
21+
else:
22+
replace_rms_norm_with_native_rms_norm(child)
23+
return module

0 commit comments

Comments
 (0)