@@ -135,9 +135,7 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None

-    def _transfer_tensor_to_device(self, tensor, source_tensor=None, current_stream=None):
-        if source_tensor is None:
-            source_tensor = tensor
+    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
         if self.record_stream and current_stream is not None:
             tensor.data.record_stream(current_stream)
@@ -159,26 +157,6 @@ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None)
             source = pinned_memory[buffer] if pinned_memory else buffer.data
             self._transfer_tensor_to_device(buffer, source, current_stream)

-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
-
-        if self.stream is not None:
-            # Wait for previous Host->Device transfer to complete
-            self.stream.synchronize()
-
-        with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
-            else:
-                self._onload_from_memory(current_stream)
-
     def _onload_from_disk(self, current_stream):
         if self.stream is not None:
             loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
@@ -207,6 +185,26 @@ def _onload_from_memory(self, current_stream):
         else:
             self._process_tensors_from_modules(None, current_stream)

+    @torch.compiler.disable()
+    def onload_(self):
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
+        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+
+        if self.stream is not None:
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
+
+        with context:
+            if self.offload_to_disk_path:
+                self._onload_from_disk(current_stream)
+            else:
+                self._onload_from_memory(current_stream)
+
     @torch.compiler.disable()
     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if _all_ the desired
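
For context, below is a minimal, self-contained sketch of the stream-overlap pattern that the relocated onload_ method and the tightened _transfer_tensor_to_device signature implement together. The class name OnloadSketch, its constructor arguments, and the pinned_sources argument are illustrative assumptions for this sketch (assuming a CUDA device), not part of the diffusers API.

# Minimal sketch of the onload pattern in the diff above.
# Illustrative names only; not the diffusers implementation.
import torch
from contextlib import nullcontext


class OnloadSketch:
    def __init__(self, onload_device="cuda", use_stream=True, record_stream=False):
        self.onload_device = torch.device(onload_device)
        self.non_blocking = use_stream  # async copies only pay off with a side stream
        self.record_stream = record_stream
        self.stream = torch.cuda.Stream() if use_stream else None

    def _transfer(self, tensor, source_tensor, current_stream=None):
        # As in the patched _transfer_tensor_to_device: the source is now a
        # required argument, so callers pass the (pinned) host copy explicitly.
        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
        if self.record_stream and current_stream is not None:
            # Tell the caching allocator the tensor is also used on
            # current_stream, so its memory is not reclaimed too early.
            tensor.data.record_stream(current_stream)

    def onload(self, tensors, pinned_sources):
        context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
        current_stream = torch.cuda.current_stream() if self.record_stream else None
        if self.stream is not None:
            # Wait for the previous host->device transfer to complete before
            # enqueueing the next batch on the same side stream.
            self.stream.synchronize()
        with context:
            for tensor, source in zip(tensors, pinned_sources):
                self._transfer(tensor, source, current_stream)

Calling onload(params, [p.pin_memory() for p in cpu_copies]) would let the host-to-device copies overlap with compute running on the default stream; this is only a sketch of the pattern under the stated assumptions.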