Skip to content

Commit 7874cd0

Browse files
authored
[TPU] Fix test assertion error from artifacts (#19825)
1 parent e0d7ede commit 7874cd0

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

tests/tests_pytorch/accelerators/test_xla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_resume_training_on_cpu(tmp_path):
5656
"""Checks if training can be resumed from a saved checkpoint on CPU."""
5757
# Train a model on TPU
5858
model = BoringModel()
59-
trainer = Trainer(max_epochs=1, accelerator="tpu", devices="auto")
59+
trainer = Trainer(max_epochs=1, accelerator="tpu", devices="auto", default_root_dir=tmp_path)
6060
trainer.fit(model)
6161

6262
if trainer.world_size != trainer.num_devices:

tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,6 @@ def on_train_start(self):
152152
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
153153
def test_num_stepping_batches_with_tpu_multi():
154154
"""Test stepping batches with the TPU strategy across multiple devices."""
155-
trainer = Trainer(accelerator="tpu", devices="auto", max_epochs=1)
155+
trainer = Trainer(accelerator="tpu", devices="auto", max_epochs=1, logger=False, enable_checkpointing=False)
156156
model = MultiprocessModel()
157157
trainer.fit(model)

0 commit comments

Comments
 (0)