Skip to content

Commit 29b9963

Browse files
ar90ncarmocca
authored andcommitted
Fix not running test codes (#13089)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 2acff1c commit 29b9963

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/strategies/test_deepspeed_strategy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -749,9 +749,7 @@ def test_deepspeed_multigpu_stage_3_resume_training(tmpdir):
749749
initial_trainer.fit(initial_model, datamodule=dm)
750750

751751
class TestCallback(Callback):
752-
def on_train_batch_start(
753-
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
754-
) -> None:
752+
def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
755753
original_deepspeed_strategy = initial_trainer.strategy
756754
current_deepspeed_strategy = trainer.strategy
757755

@@ -778,10 +776,12 @@ def on_train_batch_start(
778776
model = ModelParallelClassificationModel()
779777
trainer = Trainer(
780778
default_root_dir=tmpdir,
781-
fast_dev_run=True,
782779
strategy=DeepSpeedStrategy(stage=3),
783780
accelerator="gpu",
784781
devices=1,
782+
max_epochs=2,
783+
limit_train_batches=1,
784+
limit_val_batches=0,
785785
precision=16,
786786
callbacks=TestCallback(),
787787
enable_progress_bar=False,

0 commit comments

Comments
 (0)