Skip to content

Commit 1e63163

Browse files
committed
claymore
1 parent 579362f commit 1e63163

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,9 @@ def _assert_layer_fsdp_instance(self) -> None:
7777
assert isinstance(self.layer, FullyShardedDataParallel)
7878
assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
7979

80-
if self.trainer.precision == "16-mixed":
81-
param_dtype = torch.float32
82-
reduce_dtype = buffer_dtype = torch.float16
83-
elif self.trainer.precision == "bf16-mixed":
84-
param_dtype = torch.float32
85-
reduce_dtype = buffer_dtype = torch.bfloat16
86-
elif self.trainer.precision == "16-true":
80+
if self.trainer.precision in ("16-true", "16-mixed"):
8781
param_dtype = reduce_dtype = buffer_dtype = torch.float16
88-
elif self.trainer.precision == "bf16-true":
82+
elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
8983
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
9084
else:
9185
raise ValueError(f"Unknown precision {self.trainer.precision}")
@@ -138,15 +132,9 @@ def _assert_layer_fsdp_instance(self) -> None:
138132
assert isinstance(self.layer, torch.nn.Sequential)
139133
assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
140134

141-
if self.trainer.precision == "16-mixed":
142-
param_dtype = torch.float32
143-
reduce_dtype = buffer_dtype = torch.float16
144-
elif self.trainer.precision == "bf16-mixed":
145-
param_dtype = torch.float32
146-
reduce_dtype = buffer_dtype = torch.bfloat16
147-
elif self.trainer.precision == "16-true":
135+
if self.trainer.precision in ("16-true", "16-mixed"):
148136
param_dtype = reduce_dtype = buffer_dtype = torch.float16
149-
elif self.trainer.precision == "bf16-true":
137+
elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
150138
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
151139
else:
152140
raise ValueError(f"Unknown precision {self.trainer.precision}")

0 commit comments

Comments
 (0)