
Commit 0be658e

claymore!

1 parent: 2a4ba44
2 files changed (+6 −4 lines): Makefile, tests/tests_pytorch/strategies/test_fsdp.py

2 files changed

+6
-4
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ export PACKAGE_NAME=pytorch
 
 # In Lightning Studio, the `lightning` package comes pre-installed.
 # Uninstall it first to ensure the editable install works correctly.
-setup:
+setup: update
 	uv pip uninstall lightning pytorch-lightning lightning-fabric || true
 	uv pip install -r requirements.txt \
 		-r requirements/pytorch/base.txt \
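
In make, `setup: update` declares `update` as a prerequisite, so `make setup` now runs the `update` target before uninstalling the pre-installed packages; the `update` target itself is defined elsewhere in the Makefile and is not shown in this diff.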

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 5 additions & 3 deletions
@@ -81,6 +81,8 @@ def _assert_layer_fsdp_instance(self) -> None:
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
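
The new branch completes the precision-to-dtype mapping that the assertion walks through. As a minimal standalone sketch of that mapping (assuming the float16 branch above covers the 16-bit precision strings; `mixed_precision_for` is a hypothetical helper, not Lightning's plugin code):

    import torch
    from torch.distributed.fsdp import MixedPrecision

    def mixed_precision_for(precision: str) -> MixedPrecision:
        # One dtype is used for parameters, gradient reduction, and buffers,
        # mirroring the chained equalities asserted in the test above.
        if precision in ("16-true", "16-mixed"):
            dtype = torch.float16
        elif precision in ("bf16-true", "bf16-mixed"):
            dtype = torch.bfloat16
        elif precision == "32-true":
            dtype = torch.float32
        else:
            raise ValueError(f"Unknown precision {precision}")
        return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)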

@@ -215,7 +217,7 @@ def test_strategy_sync_batchnorm(tmp_path):
         accelerator="gpu",
         devices=2,
         strategy="fsdp",
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         sync_batchnorm=True,
     )
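
Assembled into a runnable form, the configuration this test now exercises looks roughly like the sketch below (assuming the `lightning.pytorch` import path, with `BoringModel` as a stand-in for the test's actual module):

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel

    # FSDP with synchronized batch norm, now running in full precision.
    trainer = Trainer(
        accelerator="gpu",
        devices=2,
        strategy="fsdp",
        precision="32-true",
        max_epochs=1,
        sync_batchnorm=True,
    )
    trainer.fit(BoringModel())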
@@ -255,7 +257,7 @@ def training_step(self, batch, batch_idx):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
-@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
+@pytest.mark.parametrize("precision", ["32-true", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 @pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
 def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
@@ -347,7 +349,7 @@ def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
         accelerator="gpu",
         devices=2,
         strategy=strategy,
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         limit_train_batches=2,
         limit_val_batches=2,
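
As in the sync-batchnorm test above, this multi-GPU checkpoint test now runs under "32-true" rather than "16-mixed", so the commit consistently moves these FSDP tests from mixed to full precision; the Trainer sketch after the sync-batchnorm hunk applies here as well, with the `strategy` argument parametrized.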
