We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b2ebaaf commit dcf7738Copy full SHA for dcf7738
fastvideo/v1/models/loader/fsdp_load.py
@@ -102,10 +102,12 @@ def load_fsdp_model(
102
output_dtype: Optional[torch.dtype] = None,
103
) -> torch.nn.Module:
104
105
+ # NOTE(will): cast_forward_inputs=True shouldn't be needed as we are
106
+ # manually casting the inputs to the model
107
mp_policy = MixedPrecisionPolicy(param_dtype,
108
reduce_dtype,
109
output_dtype,
- cast_forward_inputs=True)
110
+ cast_forward_inputs=False)
111
112
with set_default_dtype(default_dtype), torch.device("meta"):
113
model = model_cls(**init_params)
0 commit comments