@@ -1863,7 +1863,12 @@ class Check(Callback):
         def on_epoch_start(self, trainer, *_):
             assert isinstance(trainer.training_type_plugin.model, DistributedDataParallel)

-    initial = torch.cuda.memory_allocated(0)
+    def current_memory():
+        # before measuring the memory force release any leftover allocations, including CUDA tensors
+        gc.collect()
+        return torch.cuda.memory_allocated(0)
+
+    initial = current_memory()

     model = TestModel()
     trainer_kwargs = dict(
@@ -1881,22 +1886,13 @@ def on_epoch_start(self, trainer, *_):
     assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
     assert trainer.callback_metrics["train_loss"].device == torch.device("cpu")

-    # before measuring the memory force release any leftover allocations, including CUDA tensors
-    gc.collect()
-    memory_1 = torch.cuda.memory_allocated(0)
-    assert memory_1 == initial
+    assert current_memory() <= initial

     deepcopy(trainer)

-    # before measuring the memory force release any leftover allocations, including CUDA tensors
-    gc.collect()
-    memory_2 = torch.cuda.memory_allocated(0)
-    assert memory_2 == initial
+    assert current_memory() <= initial

     trainer_2 = Trainer(**trainer_kwargs)
     trainer_2.fit(model)

-    # before measuring the memory force release any leftover allocations, including CUDA tensors
-    gc.collect()
-    memory_3 = torch.cuda.memory_allocated(0)
-    assert memory_3 == initial
+    assert current_memory() <= initial
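The change above collapses the repeated `gc.collect()` + `torch.cuda.memory_allocated(0)` bookkeeping into a single `current_memory()` helper and relaxes the checks from `==` to `<=`. A minimal standalone sketch of the same measurement pattern (a hypothetical script for illustration, not part of this commit) could look like:

```python
import gc

import torch


def current_memory(device: int = 0) -> int:
    # Force Python to drop unreachable objects (including CUDA tensors)
    # before reading the allocator counter, so stale references do not
    # inflate the measurement.
    gc.collect()
    return torch.cuda.memory_allocated(device)


if torch.cuda.is_available():
    initial = current_memory()

    # Allocate and then drop a tensor; afterwards the allocated bytes
    # should not exceed the initial baseline.
    x = torch.randn(1024, 1024, device="cuda:0")
    del x

    assert current_memory() <= initial
```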