Skip to content

Commit c7b01a6

Browse files
committed
fixes
Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent ccf54b9 commit c7b01a6

File tree

1 file changed

+2

-2
lines changed

transformer_engine/pytorch/cpu_offload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def bwd_step(self, layer_num: int):
628628
for layer in self.start_reload_map[layer_num]:
629629
self.layer_states[layer].start_reload()
630630

631-
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
631+
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
632632
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
633633
if not self.offload_layer_map.get(self.num_of_fwds, False):
634634
return tensor
@@ -679,7 +679,7 @@ def get_cpu_offload_context(
679679
offload_weights: bool = False,
680680
double_buffering: bool = False, # pylint: disable=unused-argument
681681
manual_synchronization: bool = False,
682-
retain_pinned_cpu_buffers: bool = True,
682+
retain_pinned_cpu_buffers: bool = False,
683683
offload_stream: Optional[torch.cuda.Stream] = None,
684684
):
685685
"""

0 commit comments

Comments
 (0)