@@ -19,6 +19,7 @@
 from .quantized_tensor import (
     restore_from_saved,
     prepare_for_saving,
+    QuantizedTensor,
 )
 
 
@@ -255,6 +256,8 @@ def start_offload(self):
         Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream.
         Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded.
         This event is recorded in the start_offload or push_tensor call.
+
+        Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor).
         """
         self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"])
         self.state = "offload_started"
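
For readers unfamiliar with the pattern this docstring describes, here is a minimal, self-contained sketch of event-synchronized offloading. The names (`ready_event`, `compute_stream`) and the toy shape are illustrative, not from this file:

```python
import torch

compute_stream = torch.cuda.current_stream()
offload_stream = torch.cuda.Stream()

gpu_tensor = torch.randn(1024, 1024, device="cuda")  # produced on compute_stream

# Producer records "tensor is ready" so the offload stream can order itself after it.
ready_event = torch.cuda.Event()
ready_event.record(compute_stream)

with torch.cuda.stream(offload_stream):
    offload_stream.wait_event(ready_event)  # never copy before the producer finishes
    cpu_buffer = torch.empty_like(gpu_tensor, device=torch.device("cpu"), pin_memory=True)
    cpu_buffer.copy_(gpu_tensor, non_blocking=True)  # async D2H only because the buffer is pinned
```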
@@ -275,19 +278,18 @@ def start_offload(self):
 
             with torch.cuda.stream(self.offload_stream):
                 if allocate_cpu_buffers:
-                    # empty_like is defined also for QuantizedTensors
                     offloaded_tensor = torch.empty_like(
                         tensor, device=torch.device("cpu"), pin_memory=True
                     )
                     self.cpu_tensor_group.tensor_list.append(offloaded_tensor)
                 else:
-                    assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, (
+                    offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
+                    assert offloaded_tensor.shape == tensor.shape, (
                         "CPU buffer shape does not match the offloaded tensor shape:"
-                        f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} "
-                        " Make sure that tensor shaped do not change between"
+                        f" {offloaded_tensor.shape} != {tensor.shape} "
+                        "Make sure that tensor shapes do not change between"
                         " iterations if retain_pinned_cpu_buffers is True."
                     )
-                    offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
                 offloaded_tensor.copy_(tensor, non_blocking=True)
 
         # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated,
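
The refactor above also tightens the buffer-reuse path. The key constraint is that a `non_blocking=True` device-to-host copy only overlaps with GPU work when the destination is pinned (page-locked) memory, which is why reusing pinned buffers across iterations (the `retain_pinned_cpu_buffers` path) requires stable shapes. A hedged sketch of the reuse pattern in isolation:

```python
import torch

# Iteration 1: allocate the pinned staging buffer once.
pinned = torch.empty(1024, 1024, pin_memory=True)

# Iterations 2..N: reuse it; valid only while shapes stay constant.
activation = torch.randn(1024, 1024, device="cuda")
assert pinned.shape == activation.shape, "shape changed between iterations"
pinned.copy_(activation, non_blocking=True)  # with pageable memory this would silently synchronize
```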
@@ -318,6 +320,9 @@ def start_reload(self):
         """
         Start reloading of tensors.
         It allocates new tensors on GPU and puts copy from CPU tasks on offload stream.
+
+        Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor
+        and reconstructed in pop_tensor).
         """
         self._validate_state(func_name="start_reload", allowed_states=["offload_finished"])
         self.state = "reload_started"
@@ -330,7 +335,6 @@
             # cannot move tensors from pool of one stream to another without
             # calling cudaFree and cudaMalloc again.
 
-            # empty_like is defined also for QuantizedTensors.
             reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda"))
             self.offload_stream.wait_stream(torch.cuda.current_stream())
 
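
The comment kept above refers to a real caching-allocator subtlety: PyTorch ties each allocation to the stream that was current when it was made, so a tensor allocated on the compute stream but written on the offload stream needs explicit cross-stream bookkeeping. A minimal sketch of the standard idiom (this file achieves the same ordering with `wait_stream`/events):

```python
import torch

side_stream = torch.cuda.Stream()

# Allocated while the default stream is current: the caching allocator
# associates this block with that stream.
buf = torch.empty(1 << 20, device="cuda")

with torch.cuda.stream(side_stream):
    side_stream.wait_stream(torch.cuda.current_stream())
    buf.fill_(1.0)  # written on side_stream

# Inform the allocator the block is also in use on side_stream, so it is not
# recycled for the default stream before side_stream's work completes.
buf.record_stream(side_stream)
```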
@@ -347,16 +351,29 @@ def start_reload(self):
             self.bwd_gpu_tensor_group
         )
 
-    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
         """
         It is called when a tensor is saved for backward pass.
 
         If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group.
         If tensor is not offloaded, returns the tensor itself.
+        For QuantizedTensor, returns a tuple of (push results for each component, tensor_objs).
         """
         self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"])
 
         if self._check_if_offload(tensor):
+            # For QuantizedTensor: decompose into component tensors, push each one recursively
+            if isinstance(tensor, QuantizedTensor):
+                # Make a copy because prepare_for_saving modifies the object (sets fields to None)
+                tensor_copy = tensor.detach()
+                # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass,
+                # so the generic prepare_for_saving would not call tensor.prepare_for_saving()
+                saved_tensors, tensor_obj = tensor_copy.prepare_for_saving()
+                push_results = [
+                    self.push_tensor(t) if t is not None else None for t in saved_tensors
+                ]
+                return (push_results, [tensor_obj])
+
             self.fwd_gpu_tensor_group.tensor_list.append(tensor)
             # The group is processed and offloaded at the end of the forward pass of current layer.
             # To enable offloading of tensors faster we use self.offload_stream and record
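
Since a QuantizedTensor cannot be constructed in a standalone snippet, the hypothetical `MockQuantized` below only illustrates the save/restore protocol the new branch relies on: `prepare_for_saving` strips the component tensors out of the object (hence the `detach()` copy above), and `restore_from_saved` puts the reloaded ones back.

```python
import torch

class MockQuantized:
    """Hypothetical stand-in exposing the same protocol as QuantizedTensor."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self.data, self.scale = data, scale

    def prepare_for_saving(self):
        saved = [self.data, self.scale]
        self.data = self.scale = None  # fields are cleared by design
        return saved, self

    def restore_from_saved(self, tensors):
        self.data, self.scale = tensors

qt = MockQuantized(torch.randn(4, 4), torch.tensor(0.5))
saved_tensors, obj = qt.prepare_for_saving()
# ... each entry of saved_tensors would be pushed (and possibly offloaded) ...
obj.restore_from_saved(saved_tensors)
assert obj.data is not None and obj.scale is not None
```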
@@ -370,23 +387,39 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
             return len(self.fwd_gpu_tensor_group.tensor_list) - 1
         return tensor
 
-    def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
+    def pop_tensor(
+        self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
+    ) -> torch.Tensor:
         """
         It is called when a tensor is used in backward pass.
         Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish.
+        For QuantizedTensor (tuple input), reconstructs from component tensors.
         """
         self._validate_state(
             func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"]
         )
 
-        # 1. tensor not offloaded
+        # 1. tensor not offloaded (regular tensor returned as-is from push)
         if isinstance(tensor_or_tensor_id, torch.Tensor):
             return tensor_or_tensor_id
-        # 2. the layer was not offloaded at all
+
+        # 2. QuantizedTensor case: tuple of (push_results, tensor_objs)
+        if isinstance(tensor_or_tensor_id, tuple):
+            push_results, tensor_objs = tensor_or_tensor_id
+            # Recursively pop each component
+            reloaded_tensors = [
+                self.pop_tensor(pr) if pr is not None else None for pr in push_results
+            ]
+            # Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy
+            tensor_obj = tensor_objs[0]
+            tensor_obj.restore_from_saved(reloaded_tensors)
+            return tensor_obj
+
+        # 3. Regular tensor index case
         if self.state == "not_offloaded":
             return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]
 
-        # 3. the layer was offloaded
+        # 4. the layer was offloaded
         assert self.state == "reload_started"
         # wait for the tensor to be reloaded
         torch.cuda.current_stream().wait_event(
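
The tail of `pop_tensor` (cut off by the hunk) makes the compute stream wait on the reload event rather than blocking the host. In isolation, the consumer side of that handshake looks roughly like this (toy names, assuming the event was recorded after the H2D copy):

```python
import torch

offload_stream = torch.cuda.Stream()
reload_done = torch.cuda.Event()

cpu_saved = torch.randn(1024, 1024, pin_memory=True)
gpu_dst = torch.empty(1024, 1024, device="cuda")

with torch.cuda.stream(offload_stream):
    gpu_dst.copy_(cpu_saved, non_blocking=True)  # H2D reload off the critical path
    reload_done.record(offload_stream)
gpu_dst.record_stream(offload_stream)  # allocator bookkeeping for cross-stream use

# Consumer: stall the compute stream, not the CPU, until the reload lands.
torch.cuda.current_stream().wait_event(reload_done)
result = gpu_dst.sum()  # safe to read now
```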
@@ -406,6 +439,10 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
         """
         Check if tensor needs to be offloaded.
         """
+        # Only offload tensors with at least 256k elements (~1MB for float32)
+        if t.numel() < 256 * 1024:
+            return False
+
         if (
             not isinstance(t, torch.nn.Parameter)
             and not getattr(t, "_TE_do_not_offload", False)
@@ -418,7 +455,6 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
                     " this tensor will be skipped."
                 )
                 return False
-
             return True
         return False
 
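
The 256k-element floor translates to different byte sizes depending on dtype, so the "~1MB for float32" note only holds for 4-byte elements. A quick check:

```python
import torch

threshold_elems = 256 * 1024
for dtype in (torch.float32, torch.bfloat16, torch.uint8):
    mib = threshold_elems * dtype.itemsize / 2**20
    print(f"{dtype}: {mib:.2f} MiB")
# torch.float32: 1.00 MiB
# torch.bfloat16: 0.50 MiB
# torch.uint8: 0.25 MiB
```

Below that size, the launch and synchronization overhead of the copies tends to outweigh the memory savings.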
@@ -488,11 +524,13 @@ def bwd_step(self, layer_num: int):
         self.previous_bwd_layer_id = layer_num
         self.current_layer_id = layer_num
 
-    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
         """Default push tensor method"""
         return self.layer_states[self.num_of_fwds].push_tensor(tensor)
 
-    def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
+    def pop_tensor(
+        self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
+    ) -> torch.Tensor:
         """Default pop tensor method"""
         return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id)
 
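
For context on where `push_tensor`/`pop_tensor` get called from: the pair matches PyTorch's pack/unpack hook interface, and a synchronizer like this is typically installed via `torch.autograd.graph.saved_tensors_hooks`. A runnable toy with a pass-through stand-in (`PassthroughSync` is hypothetical, not this class):

```python
import torch

class PassthroughSync:
    """Toy synchronizer: a real push/pop would offload to CPU and reload."""

    def push_tensor(self, t):
        return t  # could also return an index or a tuple, as in the diff

    def pop_tensor(self, handle):
        return handle

sync = PassthroughSync()

with torch.autograd.graph.saved_tensors_hooks(sync.push_tensor, sync.pop_tensor):
    x = torch.randn(8, 8, requires_grad=True)
    y = (x * x).sum()  # the activation saved for backward goes through push_tensor
y.backward()  # pop_tensor runs when backward needs it
```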
@@ -592,6 +630,12 @@ def bwd_step(self, layer_num: int):
         for layer in self.start_reload_map[layer_num]:
             self.layer_states[layer].start_reload()
 
+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
+        """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
+        if not self.offload_layer_map.get(self.num_of_fwds, False):
+            return tensor
+        return self.layer_states[self.num_of_fwds].push_tensor(tensor)
+
 
 class ManualOffloadSynchronizer(OffloadSynchronizer):
     """