@@ -159,27 +159,27 @@ def _pinned_memory_tensors(self):
159159        finally :
160160            pinned_dict  =  None 
161161
def _transfer_tensor_to_device(self, tensor, source_tensor):
    """Move ``source_tensor`` to the onload device and store it as ``tensor``'s data.

    The copy honors ``self.non_blocking``; when ``self.record_stream`` is set, the
    freshly transferred storage is registered with the accelerator's current stream
    via ``Tensor.record_stream`` so the caching allocator does not reuse that memory
    while the stream may still be consuming it.
    """
    # In-place swap of the parameter/buffer storage; callers keep the same tensor object.
    tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
    if self.record_stream:
        # NOTE(review): the current stream is queried here (inside whatever stream
        # context the caller established) rather than being passed in — assumes this
        # runs under the intended stream context; confirm against _onload_from_memory.
        tensor.data.record_stream(self._torch_accelerator_module.current_stream())
166166
def _process_tensors_from_modules(self, pinned_memory=None):
    """Transfer every parameter and buffer tracked by this group to the onload device.

    Iterates the parameters and buffers of each module in ``self.modules``, then the
    standalone tensors in ``self.parameters`` and ``self.buffers``, delegating each
    transfer to ``self._transfer_tensor_to_device``.

    Args:
        pinned_memory: Optional mapping from the original tensor object to its
            pinned-host-memory copy (as yielded by ``_pinned_memory_tensors``). When
            provided (truthy), the pinned copy is used as the transfer source, which
            enables fast/async host-to-device copies; otherwise each tensor's own
            ``.data`` is the source.
    """
    # presumably pinned_memory is keyed by tensor object identity — TODO confirm
    # against _pinned_memory_tensors.
    for group_module in self.modules:
        for param in group_module.parameters():
            source = pinned_memory[param] if pinned_memory else param.data
            self._transfer_tensor_to_device(param, source)
        for buffer in group_module.buffers():
            source = pinned_memory[buffer] if pinned_memory else buffer.data
            self._transfer_tensor_to_device(buffer, source)

    # Tensors tracked directly by the group, outside any module.
    for param in self.parameters:
        source = pinned_memory[param] if pinned_memory else param.data
        self._transfer_tensor_to_device(param, source)

    for buffer in self.buffers:
        source = pinned_memory[buffer] if pinned_memory else buffer.data
        self._transfer_tensor_to_device(buffer, source)
183183
184184    def  _onload_from_disk (self ):
185185        if  self .stream  is  not None :
@@ -214,14 +214,12 @@ def _onload_from_memory(self):
214214            self .stream .synchronize ()
215215
216216        context  =  nullcontext () if  self .stream  is  None  else  self ._torch_accelerator_module .stream (self .stream )
217-         current_stream  =  self ._torch_accelerator_module .current_stream () if  self .record_stream  else  None 
218- 
219217        with  context :
220218            if  self .stream  is  not None :
221219                with  self ._pinned_memory_tensors () as  pinned_memory :
222-                     self ._process_tensors_from_modules (pinned_memory ,  current_stream )
220+                     self ._process_tensors_from_modules (pinned_memory )
223221            else :
224-                 self ._process_tensors_from_modules (None ,  current_stream )
222+                 self ._process_tensors_from_modules (None )
225223
226224    def  _offload_to_disk (self ):
227225        # TODO: we can potentially optimize this code path by checking if the _all_ the desired 
0 commit comments