
Commit e622bca

awaelchli, carmocca, and pre-commit-ci[bot] authored and committed
Reduce flakiness of memory test (#8651)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b845c44 commit e622bca

File tree

1 file changed: +9 -13 lines changed

tests/trainer/test_trainer.py

Lines changed: 9 additions & 13 deletions
@@ -1863,7 +1863,12 @@ class Check(Callback):
         def on_epoch_start(self, trainer, *_):
             assert isinstance(trainer.training_type_plugin.model, DistributedDataParallel)
 
-    initial = torch.cuda.memory_allocated(0)
+    def current_memory():
+        # before measuring the memory force release any leftover allocations, including CUDA tensors
+        gc.collect()
+        return torch.cuda.memory_allocated(0)
+
+    initial = current_memory()
 
     model = TestModel()
     trainer_kwargs = dict(
@@ -1881,22 +1886,13 @@ def on_epoch_start(self, trainer, *_):
     assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
     assert trainer.callback_metrics["train_loss"].device == torch.device("cpu")
 
-    # before measuring the memory force release any leftover allocations, including CUDA tensors
-    gc.collect()
-    memory_1 = torch.cuda.memory_allocated(0)
-    assert memory_1 == initial
+    assert current_memory() <= initial
 
     deepcopy(trainer)
 
-    # before measuring the memory force release any leftover allocations, including CUDA tensors
-    gc.collect()
-    memory_2 = torch.cuda.memory_allocated(0)
-    assert memory_2 == initial
+    assert current_memory() <= initial
 
     trainer_2 = Trainer(**trainer_kwargs)
     trainer_2.fit(model)
 
-    # before measuring the memory force release any leftover allocations, including CUDA tensors
-    gc.collect()
-    memory_3 = torch.cuda.memory_allocated(0)
-    assert memory_3 == initial
+    assert current_memory() <= initial
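
For context, the pattern this commit introduces can be sketched in isolation: a helper that runs gc.collect() before reading torch.cuda.memory_allocated, plus a <= comparison against a baseline instead of strict equality, since allocator caching and leftover Python references can make an exact == check flaky. The sketch below is not code from the repository; the demo allocation in the __main__ block is illustrative only and assumes a CUDA device at index 0.

import gc

import torch


def current_memory(device: int = 0) -> int:
    """Return allocated CUDA memory after forcing a garbage-collection pass.

    Collecting first releases leftover Python objects that still hold CUDA
    tensors, so the reading reflects only allocations that are truly alive.
    """
    gc.collect()
    return torch.cuda.memory_allocated(device)


if __name__ == "__main__":
    assert torch.cuda.is_available(), "this sketch needs a CUDA device"

    initial = current_memory()

    # Allocate a tensor, then drop the only reference to it.
    x = torch.randn(1024, 1024, device="cuda:0")
    assert current_memory() > initial
    del x

    # `<=` instead of `==` mirrors the commit: the reading should never exceed
    # the baseline once the reference is gone, while strict equality is more
    # fragile across allocator and interpreter states.
    assert current_memory() <= initial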
