Skip to content

Commit 6aa644a

Browse files
committed
precommit fix
1 parent f24ea83 commit 6aa644a

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

tests/tests_pytorch/trainer/test_memory_leak.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
23
import psutil
34
import pytest
45
import torch
@@ -25,16 +26,13 @@ def __init__(self, size=1000, dim=1000):
2526
def __len__(self):
2627
return len(self.data)
2728

28-
def __getitem__(self, idx):
29-
return self.data[idx], self.targets[idx]
30-
3129
def __iter__(self):
3230
for i in range(len(self)):
3331
yield self[i]
3432

3533
def __getitem__(self, idx):
3634
# During prediction, return only the input tensor
37-
if hasattr(self, 'prediction_mode') and self.prediction_mode:
35+
if hasattr(self, "prediction_mode") and self.prediction_mode:
3836
return self.data[idx]
3937
return self.data[idx], self.targets[idx]
4038

@@ -66,18 +64,18 @@ def test_prediction_memory_leak(tmp_path, return_predictions):
6664
devices=1,
6765
max_epochs=1,
6866
)
69-
70-
predictions = trainer.predict(model, dataloaders=dataloader, return_predictions=return_predictions)
71-
67+
68+
trainer.predict(model, dataloaders=dataloader, return_predictions=return_predictions)
69+
7270
# Get final memory usage
7371
final_memory = get_memory_usage()
74-
72+
7573
# Calculate memory growth
7674
memory_growth = final_memory - initial_memory
77-
75+
7876
# When return_predictions=False, memory growth should be minimal
7977
if not return_predictions:
8078
assert memory_growth < 100, f"Memory growth {memory_growth}MB is too high when return_predictions=False"
8179
else:
8280
# When return_predictions=True, we expect some memory growth due to storing predictions
83-
assert memory_growth > 0, "Expected memory growth when storing predictions"
81+
assert memory_growth > 0, "Expected memory growth when storing predictions"

0 commit comments

Comments
 (0)