Skip to content

Commit dcf7738

Browse files
[Misc] disable cast_forward_inputs (#460)
1 parent b2ebaaf commit dcf7738

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

fastvideo/v1/models/loader/fsdp_load.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,12 @@ def load_fsdp_model(
102102
output_dtype: Optional[torch.dtype] = None,
103103
) -> torch.nn.Module:
104104

105+
# NOTE(will): cast_forward_inputs=True shouldn't be needed as we are
106+
# manually casting the inputs to the model
105107
mp_policy = MixedPrecisionPolicy(param_dtype,
106108
reduce_dtype,
107109
output_dtype,
108-
cast_forward_inputs=True)
110+
cast_forward_inputs=False)
109111

110112
with set_default_dtype(default_dtype), torch.device("meta"):
111113
model = model_cls(**init_params)

0 commit comments

Comments
 (0)