diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py
index 385998a8c5..0b4e3014f7 100644
--- a/tests/pytorch/test_cpu_offloading.py
+++ b/tests/pytorch/test_cpu_offloading.py
@@ -54,9 +54,11 @@ class Utils:
+    # Tensor big enough that both the data and the scaling-factor tensors are bigger than
+    # 256 * 1024 elements, so that they are offloaded to CPU.
     tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
-    _B = 64
-    _S = 256
+    _B = 128
+    _S = 512
     _H = 4
     _D = 256
diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py
index d0b8d3474e..a52ce9e03b 100644
--- a/transformer_engine/pytorch/cpu_offload.py
+++ b/transformer_engine/pytorch/cpu_offload.py
@@ -19,6 +19,7 @@
 from .quantized_tensor import (
     restore_from_saved,
     prepare_for_saving,
+    QuantizedTensor,
 )
@@ -255,6 +256,8 @@ def start_offload(self):
         Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream.
         Before each copy event, the offload stream waits for the event signalling that the tensor
         is ready to be offloaded. This event is recorded in the start_offload or push_tensor call.
+
+        Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor).
         """
         self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"])
         self.state = "offload_started"
@@ -275,19 +278,18 @@
             with torch.cuda.stream(self.offload_stream):
                 if allocate_cpu_buffers:
-                    # empty_like is defined also for QuantizedTensors
                     offloaded_tensor = torch.empty_like(
                         tensor, device=torch.device("cpu"), pin_memory=True
                     )
                     self.cpu_tensor_group.tensor_list.append(offloaded_tensor)
                 else:
-                    assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, (
+                    offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
+                    assert offloaded_tensor.shape == tensor.shape, (
                         "CPU buffer shape does not match the offloaded tensor shape:"
-                        f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} "
-                        " Make sure that tensor shaped do not change between"
+                        f" {offloaded_tensor.shape} != {tensor.shape} "
+                        "Make sure that tensor shapes do not change between"
                         " iterations if retain_pinned_cpu_buffers is True."
                     )
-                    offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
                 offloaded_tensor.copy_(tensor, non_blocking=True)

         # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated,
@@ -318,6 +320,9 @@ def start_reload(self):
         """
         Start reloading of tensors. It allocates new tensors on GPU and puts
         copy from CPU tasks on offload stream.
+
+        Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor
+        and reconstructed in pop_tensor).
         """
         self._validate_state(func_name="start_reload", allowed_states=["offload_finished"])
         self.state = "reload_started"
@@ -330,7 +335,6 @@
             # cannot move tensors from pool of one stream to another without
             # calling cudaFree and cudaMalloc again.
-            # empty_like is defined also for QuantizedTensors.
             reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda"))

             self.offload_stream.wait_stream(torch.cuda.current_stream())
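The start_offload/start_reload changes above keep the usual device-to-host offload pattern: the destination is a pinned CPU buffer (optionally reused across iterations), the copy runs on a dedicated offload stream, and events order the copy against the producer. Below is a minimal sketch of that pattern using plain PyTorch only; the names offload_stream, gpu_tensor and cpu_buffer are illustrative, not identifiers from the diff.

    import torch

    offload_stream = torch.cuda.Stream()                     # analogous to self.offload_stream
    gpu_tensor = torch.randn(1024, 1024, device="cuda")

    # Pinned (page-locked) host buffer: required for an asynchronous D2H copy.
    cpu_buffer = torch.empty_like(gpu_tensor, device=torch.device("cpu"), pin_memory=True)

    ready = torch.cuda.Event()
    ready.record()                                           # producer finished writing gpu_tensor

    with torch.cuda.stream(offload_stream):
        offload_stream.wait_event(ready)                     # do not copy before the tensor is ready
        cpu_buffer.copy_(gpu_tensor, non_blocking=True)      # offload

    # Reload: allocate on the compute stream's pool, copy back on the offload stream.
    reloaded = torch.empty_like(cpu_buffer, device=torch.device("cuda"))
    offload_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(offload_stream):
        reloaded.copy_(cpu_buffer, non_blocking=True)
    torch.cuda.current_stream().wait_stream(offload_stream)  # consumer waits for the reload

Reusing cpu_buffer across iterations (retain_pinned_cpu_buffers, whose default changes later in this diff) only works when shapes stay constant, which is exactly what the reworded assert message spells out.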
@@ -347,16 +351,29 @@
             self.bwd_gpu_tensor_group
         )

-    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
         """
         It is called when a tensor is saved for backward pass.
         If tensor is offloaded, returns int representing the index of the tensor in the offloaded
         tensor group. If tensor is not offloaded, returns the tensor itself.
+        For QuantizedTensor, returns a (push results for each component, tensor_objs) tuple.
         """
         self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"])

         if self._check_if_offload(tensor):
+            # For QuantizedTensor: decompose into component tensors, push each one recursively
+            if isinstance(tensor, QuantizedTensor):
+                # Make a copy because prepare_for_saving modifies the object (sets fields to None)
+                tensor_copy = tensor.detach()
+                # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass,
+                # so the generic prepare_for_saving would not call tensor.prepare_for_saving()
+                saved_tensors, tensor_obj = tensor_copy.prepare_for_saving()
+                push_results = [
+                    self.push_tensor(t) if t is not None else None for t in saved_tensors
+                ]
+                return (push_results, [tensor_obj])
+
             self.fwd_gpu_tensor_group.tensor_list.append(tensor)
             # The group is processed and offloaded at the end of the forward pass of current layer.
             # To enable offloading of tensors faster we use self.offload_stream and record
@@ -370,23 +387,39 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
             return len(self.fwd_gpu_tensor_group.tensor_list) - 1
         return tensor

-    def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
+    def pop_tensor(
+        self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
+    ) -> torch.Tensor:
         """
         It is called when a tensor is used in backward pass. Returns the tensor.
         If tensor was offloaded/reloaded, wait for the reload of a tensor to finish.
+        For QuantizedTensor (tuple input), reconstructs from component tensors.
         """
         self._validate_state(
             func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"]
         )
-        # 1. tensor not offloaded
+        # 1. tensor not offloaded (regular tensor returned as-is from push)
         if isinstance(tensor_or_tensor_id, torch.Tensor):
             return tensor_or_tensor_id
-        # 2. the layer was not offloaded at all
+
+        # 2. QuantizedTensor case: tuple of (push_results, tensor_objs)
+        if isinstance(tensor_or_tensor_id, tuple):
+            push_results, tensor_objs = tensor_or_tensor_id
+            # Recursively pop each component
+            reloaded_tensors = [
+                self.pop_tensor(pr) if pr is not None else None for pr in push_results
+            ]
+            # Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy
+            tensor_obj = tensor_objs[0]
+            tensor_obj.restore_from_saved(reloaded_tensors)
+            return tensor_obj
+
+        # 3. Regular tensor index case
         if self.state == "not_offloaded":
             return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]
-        # 3. the layer was offloaded
+        # 4. the layer was offloaded
         assert self.state == "reload_started"
         # wait for the tensor to be reloaded
         torch.cuda.current_stream().wait_event(
@@ -406,6 +439,10 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
         """
         Check if tensor needs to be offloaded.
         """
+        # Only offload tensors with at least 256k elements (~1MB for float32)
+        if t.numel() < 256 * 1024:
+            return False
+
         if (
             not isinstance(t, torch.nn.Parameter)
             and not getattr(t, "_TE_do_not_offload", False)
@@ -418,7 +455,6 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
                     " this tensor will be skipped."
                 )
                 return False
-
             return True
         return False
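The new push_tensor/pop_tensor branches lean on the prepare_for_saving / restore_from_saved contract: a quantized tensor is flattened into its plain component tensors, each component is pushed (and possibly offloaded) on its own, and the object is rebuilt from the reloaded components in backward. Components that fall under the 256 * 1024-element threshold (typically scale tensors) are returned as-is by push_tensor and simply stay on the GPU. The toy class below shows the shape of that round trip; it is a stand-in, not the real QuantizedTensor API.

    import torch

    class ToyQuantized:
        """Stand-in for a QuantizedTensor holding plain component tensors."""

        def __init__(self, data, scale_inv):
            self._data, self._scale_inv = data, scale_inv

        def prepare_for_saving(self):
            # Hand out the components and clear the fields, as TE's method does in place.
            tensors = [self._data, self._scale_inv]
            self._data = self._scale_inv = None
            return tensors, self

        def restore_from_saved(self, tensors):
            self._data, self._scale_inv = tensors

    q = ToyQuantized(
        torch.zeros(1024, 1024, dtype=torch.uint8, device="cuda"),
        torch.ones(1024, dtype=torch.float32, device="cuda"),
    )
    saved, obj = q.prepare_for_saving()  # components pushed individually in push_tensor
    obj.restore_from_saved(saved)        # rebuilt from the reloaded components in pop_tensor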
@@ -592,6 +628,12 @@ def bwd_step(self, layer_num: int):
         for layer in self.start_reload_map[layer_num]:
             self.layer_states[layer].start_reload()

+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+        """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
+        if not self.offload_layer_map.get(self.num_of_fwds, False):
+            return tensor
+        return self.layer_states[self.num_of_fwds].push_tensor(tensor)
+

 class ManualOffloadSynchronizer(OffloadSynchronizer):
     """
@@ -637,7 +679,7 @@ def get_cpu_offload_context(
     offload_weights: bool = False,
     double_buffering: bool = False,  # pylint: disable=unused-argument
     manual_synchronization: bool = False,
-    retain_pinned_cpu_buffers: bool = False,
+    retain_pinned_cpu_buffers: bool = True,
     offload_stream: Optional[torch.cuda.Stream] = None,
 ):
     """
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 52ef02a347..60b931abfd 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -254,7 +254,8 @@ std::vector multi_tensor_quantize(const std::vector &ten
 std::vector split_quantize(const at::Tensor &tensor,
                            const std::vector &split_sections,
-                           std::vector quantizer_list);
+                           std::vector quantizer_list,
+                           bool disable_bulk_allocation = false);

 /***************************************************************************************************
  * Bias gradient fusions
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 4e5e5223f7..ac06841879 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -1095,7 +1095,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
 std::vector split_quantize(const at::Tensor &tensor,
                            const std::vector &split_sections,
-                           std::vector quantizer_list) {
+                           std::vector quantizer_list,
+                           bool disable_bulk_allocation) {
   init_extension();

   // Check number of tensors
@@ -1147,22 +1148,24 @@
   enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 };
   AllocationMethod allocation_method = AllocationMethod::UNFUSED;
   QuantizationMethod quantization_method = QuantizationMethod::UNFUSED;
-  if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
-                  [](const py::handle &quantizer) -> bool {
-                    return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
-                  })) {
-    allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
-  } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
-                         [](const py::handle &quantizer) -> bool {
-                           return detail::IsMXFP8Quantizers(quantizer.ptr());
-                         })) {
-    allocation_method = AllocationMethod::BULK_MXFP8;
-  } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
-                         [](const py::handle &quantizer) -> bool {
-                           return detail::IsNVFP4Quantizers(quantizer.ptr());
-                         })) {
-    allocation_method = AllocationMethod::BULK_NVFP4;
-    quantization_method = QuantizationMethod::FUSED_NVFP4;
+  if (!disable_bulk_allocation) {
+    if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
+                    [](const py::handle &quantizer) -> bool {
+                      return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
+                    })) {
+      allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
+    } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
+                           [](const py::handle &quantizer) -> bool {
+                             return detail::IsMXFP8Quantizers(quantizer.ptr());
+                           })) {
+      allocation_method = AllocationMethod::BULK_MXFP8;
+    } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
+                           [](const py::handle &quantizer) -> bool {
+                             return detail::IsNVFP4Quantizers(quantizer.ptr());
+                           })) {
+      allocation_method = AllocationMethod::BULK_NVFP4;
+      quantization_method = QuantizationMethod::FUSED_NVFP4;
+    }
   }

   // Allocate output tensors
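The disable_bulk_allocation flag exists because bulk allocation packs the quantized outputs of split_quantize into shared backing buffers, while the offload path above wants to copy and free each saved tensor independently; a slice of a shared allocation cannot release its GPU memory on its own. A small PyTorch illustration of that underlying issue (plain torch.split stands in for the bulk-allocated outputs, which is an assumption about the allocation scheme, not code from this diff):

    import torch

    # Bulk allocation: one backing buffer, per-split outputs are views into it.
    bulk = torch.empty(4 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    outs = list(torch.split(bulk, 1024 * 1024))

    # Copying outs[0] to the CPU does not free any GPU memory: the bytes are still
    # owned by the shared allocation that also backs outs[1:].
    assert outs[0].untyped_storage().data_ptr() == bulk.untyped_storage().data_ptr()

With cpu_offloading enabled, GroupedLinear therefore passes disable_bulk_allocation=True (see the grouped_linear.py hunk below) so that every quantized output owns its storage and can be offloaded and released on its own.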
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index e73eca7861..c5c8905294 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
   m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
         "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
-        py::arg("quantizer_list"));
+        py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
   m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
         "Grouped GEMM");
   m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index d0a5618afb..d8d7456e56 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -143,7 +143,9 @@ def forward(
         inp_view = inp.reshape(-1, in_features)
         inputmats: list
         if fp8 and not debug:
-            inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
+            inputmats = tex.split_quantize(
+                inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
+            )
         elif debug:
             inputmats = DebugQuantizer.multi_tensor_quantize(
                 inp_view, input_quantizers, m_splits, activation_dtype
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index f3220d5860..b8349f84a0 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -428,7 +428,8 @@ def forward(
             # weights if weights are externally touched outside this module
             ctx.weight_object = weight

-            mark_not_offload(weight, weightmat, bias)
+            if cpu_offloading:
+                mark_not_offload(weight, weightmat, bias)
             # TODO(ksivamani): Check memory usage
             tensors_to_save, tensor_objects = prepare_for_saving(
                 saved_inputmat,
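mark_not_offload is now only called when cpu_offloading is active, which avoids touching the weight/bias tensors on every forward pass when offloading is off. Judging from _check_if_offload above, the exclusion presumably works by flagging tensors much like the _TE_do_not_offload attribute it consults; the snippet below marks a tensor with that attribute directly and is an illustration, not the helper's actual implementation.

    import torch

    bias = torch.randn(1024, device="cuda")
    # Excluded from offloading: _check_if_offload sees the flag and returns False.
    bias._TE_do_not_offload = True

Parameters are skipped regardless, since _check_if_offload already rejects torch.nn.Parameter instances.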
""" dtype = self.name_to_dtype_map[state_name] + # Handle QuantizedTensor by dequantizing first + param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param if store_param_remainders: - data = torch.zeros(param.shape, dtype=torch.int16, device=param.device) + data = torch.zeros_like(param_for_empty, dtype=torch.int16) else: - data = torch.empty(param.shape, dtype=dtype, device=param.device) + data = torch.empty_like(param_for_empty, dtype=dtype) if zero_buffer: data.zero_() diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 3414581f7c..ac827e794a 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -20,11 +20,6 @@ _stride_from_shape, ) -_quantized_tensor_cpu_supported_ops = ( - torch.ops.aten.empty_like.default, - torch.ops.aten.copy_.default, -) - class QuantizedTensorStorage: r"""Base class for all TensorStorage classes. @@ -539,15 +534,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - def check_if_cpu(arg): - if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu": - assert ( - func in _quantized_tensor_cpu_supported_ops - ), f"QuantizedTensor on CPU does not support this operation: {func}" - return arg - - args = tree_map(check_if_cpu, args) - # Do not force the QuantizedTensor type on the returned tensor return torch._C._disabled_torch_function_impl(func, types, args, kwargs)