@@ -1,4 +1,5 @@
 import os
+
 import psutil
 import pytest
 import torch
@@ -25,16 +26,13 @@ def __init__(self, size=1000, dim=1000):
     def __len__(self):
         return len(self.data)
 
-    def __getitem__(self, idx):
-        return self.data[idx], self.targets[idx]
-
     def __iter__(self):
         for i in range(len(self)):
             yield self[i]
 
     def __getitem__(self, idx):
         # During prediction, return only the input tensor
-        if hasattr(self, 'prediction_mode') and self.prediction_mode:
+        if hasattr(self, "prediction_mode") and self.prediction_mode:
             return self.data[idx]
         return self.data[idx], self.targets[idx]
 
@@ -66,18 +64,18 @@ def test_prediction_memory_leak(tmp_path, return_predictions):
         devices=1,
         max_epochs=1,
     )
-
-    predictions = trainer.predict(model, dataloaders=dataloader, return_predictions=return_predictions)
-
+
+    trainer.predict(model, dataloaders=dataloader, return_predictions=return_predictions)
+
     # Get final memory usage
     final_memory = get_memory_usage()
-
+
     # Calculate memory growth
     memory_growth = final_memory - initial_memory
-
+
     # When return_predictions=False, memory growth should be minimal
     if not return_predictions:
         assert memory_growth < 100, f"Memory growth {memory_growth} MB is too high when return_predictions=False"
     else:
         # When return_predictions=True, we expect some memory growth due to storing predictions
-        assert memory_growth > 0, "Expected memory growth when storing predictions"
+        assert memory_growth > 0, "Expected memory growth when storing predictions"
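
The assertions above compare against get_memory_usage(), a helper defined earlier in the test file but outside this diff. Given the psutil import at the top of the file and the 100 MB threshold, it most likely reports the current process's resident set size in megabytes. A minimal sketch under that assumption (the helper name comes from the diff; the body here is a guess):

import os

import psutil


def get_memory_usage():
    # Assumed helper: resident set size (RSS) of the current process,
    # converted to MB to match the test's 100 MB growth threshold.
    return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)

Note also that prediction_mode is guarded with hasattr in __getitem__ because it is never set in __init__; the test presumably sets dataset.prediction_mode = True before calling trainer.predict.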