Skip to content

Commit 579362f

Browse files
committed
update
1 parent 8a50fd0 commit 579362f

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

tests/tests_fabric/strategies/test_fsdp_integration.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,9 @@ def step(self, model, batch):
8787

8888
precision = self.fabric._precision
8989
assert isinstance(precision, FSDPPrecision)
90-
if precision.precision == "16-mixed":
91-
param_dtype = torch.float32
92-
reduce_dtype = buffer_dtype = torch.float16
93-
elif precision.precision == "bf16-mixed":
94-
param_dtype = torch.float32
95-
reduce_dtype = buffer_dtype = torch.bfloat16
96-
elif precision.precision == "16-true":
90+
if precision.precision in ("16-true", "16-mixed"):
9791
param_dtype = reduce_dtype = buffer_dtype = torch.float16
98-
elif precision.precision == "bf16-true":
92+
elif precision.precision in ("bf16-true", "bf16-mixed"):
9993
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
10094
else:
10195
raise ValueError(f"Unknown precision {precision.precision}")

0 commit comments

Comments
 (0)