
Commit 060cc7b

fully_shard usage on RMSNorm (#577)
1 parent 6c58a5b commit 060cc7b

File tree

1 file changed (+7, -0 lines)


fastvideo/v1/layers/layernorm.py

Lines changed: 7 additions & 0 deletions
@@ -37,7 +37,14 @@ def __init__(
         self.weight = torch.ones(hidden_size)
         if self.has_weight:
             self.weight = nn.Parameter(self.weight)
+
+
+    # if we do fully_shard(model.layer_norm), and we call layer_norm.forward_native(input) instead of layer_norm(input),
+    # we need to call model.layer_norm.register_fsdp_forward_method(model, "forward_native") to make sure fsdp2 hooks are triggered
+    # for mixed precision and cpu offloading
 
+    # the even better way might be fully_shard(model.layer_norm, mp_policy=, cpu_offloading=), and call model.layer_norm(input). everything should work out of the box
+    # because fsdp2 hooks will be triggered with model.layer_norm.__call__
     def forward_native(
         self,
         x: torch.Tensor,
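
For context, a minimal usage sketch (not part of this commit) of the two FSDP2 patterns the new comments describe. `model`, `model.layer_norm`, and `x` are placeholder names for a wrapper module, its RMSNorm submodule, and an input tensor; the imports assume a recent PyTorch where the FSDP2 API lives in torch.distributed.fsdp, and the sketch uses the module-level register_fsdp_forward_method(module, method_name) helper and the offload_policy= keyword, which are the names corresponding to the registration call and CPU offloading the comments refer to.

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
    register_fsdp_forward_method,
)


def shard_and_register_forward_native(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Pattern 1: shard the norm but keep calling forward_native directly.
    # FSDP2 hooks wrap __call__ and explicitly registered methods only, so
    # forward_native has to be registered or mixed precision / CPU offloading
    # will silently be skipped on this call path.
    fully_shard(model.layer_norm)
    register_fsdp_forward_method(model.layer_norm, "forward_native")
    return model.layer_norm.forward_native(x)


def shard_with_policies(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Pattern 2: configure the policies on fully_shard and call the module
    # normally; __call__ runs FSDP2's pre/post-forward hooks out of the box.
    fully_shard(
        model.layer_norm,
        mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16),
        offload_policy=CPUOffloadPolicy(),
    )
    return model.layer_norm(x)

The second pattern avoids the extra registration step because the ordinary module call already routes through the FSDP2 hooks.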

0 commit comments
