File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
transformer_engine/pytorch Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -628,7 +628,7 @@ def bwd_step(self, layer_num: int):
628628 for layer in self .start_reload_map [layer_num ]:
629629 self .layer_states [layer ].start_reload ()
630630
631- def push_tensor (self , tensor : torch .Tensor ) -> int | torch .Tensor :
631+ def push_tensor (self , tensor : torch .Tensor ) -> int | torch .Tensor | tuple [ list , list ] :
632632 """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
633633 if not self .offload_layer_map .get (self .num_of_fwds , False ):
634634 return tensor
@@ -679,7 +679,7 @@ def get_cpu_offload_context(
679679 offload_weights : bool = False ,
680680 double_buffering : bool = False , # pylint: disable=unused-argument
681681 manual_synchronization : bool = False ,
682- retain_pinned_cpu_buffers : bool = True ,
682+ retain_pinned_cpu_buffers : bool = False ,
683683 offload_stream : Optional [torch .cuda .Stream ] = None ,
684684):
685685 """
You can’t perform that action at this time.
0 commit comments