
Commit fe8fad5

[PyTorch] Bunch of fixes for cpu offloading (#2535)
* code drop. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fix. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* fix. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fix. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* fix. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fix. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* fix. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* test fix. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fixes. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fix. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: 2f8ae81 · commit: fe8fad5

File tree: 9 files changed (+103 / -53 lines)

tests/pytorch/test_cpu_offloading.py

Lines changed: 9 additions & 2 deletions
@@ -54,9 +54,13 @@
 
 
 class Utils:
+    # Tensor used for simulating long-running GPU work in long_job()
     tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
-    _B = 64
-    _S = 256
+    # Test tensor dimensions: _B x _S x _D = 128 x 512 x 256 = 16,777,216 elements
+    # This exceeds the 256K element threshold for offloading (cpu_offload.py line 443).
+    # For quantized tensors, scale_inv tensors (~524K elements for block scaling) also exceed threshold.
+    _B = 128
+    _S = 512
     _H = 4
     _D = 256
 
@@ -395,6 +399,9 @@ def test_multiple_tensor_offload(self, recipe):
         offload_synchronizer.push_tensor(x1)
         offload_synchronizer.push_tensor(x1)
         offload_synchronizer.push_tensor(x1)
+        # Verify x1 is not corrupted after pushing (important for QuantizedTensor)
+        if recipe is not None:
+            x1.dequantize()  # Should not raise - tensor should still be valid
         offload_synchronizer.fwd_step()
         # Only one copy of tensor on cpu is allocated.
         assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1)
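As a quick sanity check, the new dimensions clear the offload size cutoff introduced in this commit; a minimal arithmetic sketch (the `256 * 1024` constant mirrors the check added to `_check_if_offload` in cpu_offload.py):

```python
# Back-of-the-envelope check that the test activations are large enough to be offloaded.
B, S, D = 128, 512, 256
OFFLOAD_THRESHOLD = 256 * 1024        # minimum element count used by _check_if_offload

activation_elems = B * S * D          # 16,777,216 elements
assert activation_elems >= OFFLOAD_THRESHOLD
```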

transformer_engine/pytorch/cpu_offload.py

Lines changed: 58 additions & 14 deletions
@@ -19,6 +19,7 @@
 from .quantized_tensor import (
     restore_from_saved,
     prepare_for_saving,
+    QuantizedTensor,
 )
 
 
@@ -255,6 +256,8 @@ def start_offload(self):
         Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream.
         Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded.
         This event is recorded in the start_offload or push_tensor call.
+
+        Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor).
         """
         self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"])
         self.state = "offload_started"
@@ -275,19 +278,18 @@ def start_offload(self):
 
             with torch.cuda.stream(self.offload_stream):
                 if allocate_cpu_buffers:
-                    # empty_like is defined also for QuantizedTensors
                     offloaded_tensor = torch.empty_like(
                         tensor, device=torch.device("cpu"), pin_memory=True
                     )
                     self.cpu_tensor_group.tensor_list.append(offloaded_tensor)
                 else:
-                    assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, (
+                    offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
+                    assert offloaded_tensor.shape == tensor.shape, (
                         "CPU buffer shape does not match the offloaded tensor shape:"
-                        f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} "
-                        " Make sure that tensor shaped do not change between"
+                        f" {offloaded_tensor.shape} != {tensor.shape} "
+                        "Make sure that tensor shapes do not change between"
                         " iterations if retain_pinned_cpu_buffers is True."
                     )
-                    offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
                     offloaded_tensor.copy_(tensor, non_blocking=True)
 
         # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated,
@@ -318,6 +320,9 @@ def start_reload(self):
         """
         Start reloading of tensors.
         It allocates new tensors on GPU and puts copy from CPU tasks on offload stream.
+
+        Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor
+        and reconstructed in pop_tensor).
         """
         self._validate_state(func_name="start_reload", allowed_states=["offload_finished"])
         self.state = "reload_started"
@@ -330,7 +335,6 @@ def start_reload(self):
             # cannot move tensors from pool of one stream to another without
             # calling cudaFree and cudaMalloc again.
 
-            # empty_like is defined also for QuantizedTensors.
             reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda"))
             self.offload_stream.wait_stream(torch.cuda.current_stream())
 
@@ -347,16 +351,29 @@ def start_reload(self):
             self.bwd_gpu_tensor_group
         )
 
-    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
         """
         It is called when a tensor is saved for backward pass.
 
         If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group.
         If tensor is not offloaded, returns the tensor itself.
+        For QuantizedTensor, returns (list of push results for each component, tensor_objs) tuple.
         """
         self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"])
 
         if self._check_if_offload(tensor):
+            # For QuantizedTensor: decompose into component tensors, push each one recursively
+            if isinstance(tensor, QuantizedTensor):
+                # Make a copy because prepare_for_saving modifies the object (sets fields to None)
+                tensor_copy = tensor.detach()
+                # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass,
+                # so the generic prepare_for_saving would not call tensor.prepare_for_saving()
+                saved_tensors, tensor_obj = tensor_copy.prepare_for_saving()
+                push_results = [
+                    self.push_tensor(t) if t is not None else None for t in saved_tensors
+                ]
+                return (push_results, [tensor_obj])
+
             self.fwd_gpu_tensor_group.tensor_list.append(tensor)
             # The group is processed and offloaded at the end of the forward pass of current layer.
             # To enable offloading of tensors faster we use self.offload_stream and record
@@ -370,23 +387,39 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
             return len(self.fwd_gpu_tensor_group.tensor_list) - 1
         return tensor
 
-    def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
+    def pop_tensor(
+        self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
+    ) -> torch.Tensor:
         """
         It is called when a tensor is used in backward pass.
         Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish.
+        For QuantizedTensor (tuple input), reconstructs from component tensors.
         """
         self._validate_state(
             func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"]
         )
 
-        # 1. tensor not offloaded
+        # 1. tensor not offloaded (regular tensor returned as-is from push)
         if isinstance(tensor_or_tensor_id, torch.Tensor):
             return tensor_or_tensor_id
-        # 2. the layer was not offloaded at all
+
+        # 2. QuantizedTensor case: tuple of (push_results, tensor_objs)
+        if isinstance(tensor_or_tensor_id, tuple):
+            push_results, tensor_objs = tensor_or_tensor_id
+            # Recursively pop each component
+            reloaded_tensors = [
+                self.pop_tensor(pr) if pr is not None else None for pr in push_results
+            ]
+            # Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy
+            tensor_obj = tensor_objs[0]
+            tensor_obj.restore_from_saved(reloaded_tensors)
+            return tensor_obj
+
+        # 3. Regular tensor index case
         if self.state == "not_offloaded":
             return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]
 
-        # 3. the layer was offloaded
+        # 4. the layer was offloaded
         assert self.state == "reload_started"
         # wait for the tensor to be reloaded
         torch.cuda.current_stream().wait_event(
@@ -406,6 +439,10 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
         """
         Check if tensor needs to be offloaded.
         """
+        # Only offload tensors with at least 256k elements (~1MB for float32)
+        if t.numel() < 256 * 1024:
+            return False
+
         if (
             not isinstance(t, torch.nn.Parameter)
             and not getattr(t, "_TE_do_not_offload", False)
@@ -418,7 +455,6 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
                     " this tensor will be skipped."
                 )
                 return False
-
             return True
         return False
 
@@ -488,11 +524,13 @@ def bwd_step(self, layer_num: int):
         self.previous_bwd_layer_id = layer_num
         self.current_layer_id = layer_num
 
-    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
         """Default push tensor method"""
         return self.layer_states[self.num_of_fwds].push_tensor(tensor)
 
-    def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
+    def pop_tensor(
+        self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
+    ) -> torch.Tensor:
         """Default pop tensor method"""
         return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id)
 
@@ -592,6 +630,12 @@ def bwd_step(self, layer_num: int):
         for layer in self.start_reload_map[layer_num]:
             self.layer_states[layer].start_reload()
 
+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
+        """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
+        if not self.offload_layer_map.get(self.num_of_fwds, False):
+            return tensor
+        return self.layer_states[self.num_of_fwds].push_tensor(tensor)
+
 
 class ManualOffloadSynchronizer(OffloadSynchronizer):
     """

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 2 additions & 1 deletion
@@ -254,7 +254,8 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten
 
 std::vector<py::object> split_quantize(const at::Tensor &tensor,
                                        const std::vector<size_t> &split_sections,
-                                       std::vector<py::handle> quantizer_list);
+                                       std::vector<py::handle> quantizer_list,
+                                       bool disable_bulk_allocation = false);
 
 /***************************************************************************************************
  * Bias gradient fusions

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 20 additions & 17 deletions
@@ -1095,7 +1095,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
 
 std::vector<py::object> split_quantize(const at::Tensor &tensor,
                                        const std::vector<size_t> &split_sections,
-                                       std::vector<py::handle> quantizer_list) {
+                                       std::vector<py::handle> quantizer_list,
+                                       bool disable_bulk_allocation) {
   init_extension();
 
   // Check number of tensors
@@ -1147,22 +1148,24 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
   enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 };
   AllocationMethod allocation_method = AllocationMethod::UNFUSED;
   QuantizationMethod quantization_method = QuantizationMethod::UNFUSED;
-  if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
-                  [](const py::handle &quantizer) -> bool {
-                    return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
-                  })) {
-    allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
-  } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
-                         [](const py::handle &quantizer) -> bool {
-                           return detail::IsMXFP8Quantizers(quantizer.ptr());
-                         })) {
-    allocation_method = AllocationMethod::BULK_MXFP8;
-  } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
-                         [](const py::handle &quantizer) -> bool {
-                           return detail::IsNVFP4Quantizers(quantizer.ptr());
-                         })) {
-    allocation_method = AllocationMethod::BULK_NVFP4;
-    quantization_method = QuantizationMethod::FUSED_NVFP4;
+  if (!disable_bulk_allocation) {
+    if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
+                    [](const py::handle &quantizer) -> bool {
+                      return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
+                    })) {
+      allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
+    } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
+                           [](const py::handle &quantizer) -> bool {
+                             return detail::IsMXFP8Quantizers(quantizer.ptr());
+                           })) {
+      allocation_method = AllocationMethod::BULK_MXFP8;
+    } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
+                           [](const py::handle &quantizer) -> bool {
+                             return detail::IsNVFP4Quantizers(quantizer.ptr());
+                           })) {
+      allocation_method = AllocationMethod::BULK_NVFP4;
+      quantization_method = QuantizationMethod::FUSED_NVFP4;
+    }
   }
 
   // Allocate output tensors

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 1 addition & 1 deletion
@@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
   m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
         "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
-        py::arg("quantizer_list"));
+        py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
   m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
         "Grouped GEMM");
   m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 6 additions & 1 deletion
@@ -143,7 +143,12 @@ def forward(
         inp_view = inp.reshape(-1, in_features)
         inputmats: list
         if fp8 and not debug:
-            inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
+            # Disable bulk allocation when CPU offloading is active: offloading skips small
+            # tensors (like scales), but bulk allocation shares storage across all tensors,
+            # so if scales can't be offloaded, nothing in the group can be offloaded.
+            inputmats = tex.split_quantize(
+                inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
+            )
         elif debug:
             inputmats = DebugQuantizer.multi_tensor_quantize(
                 inp_view, input_quantizers, m_splits, activation_dtype

transformer_engine/pytorch/module/linear.py

Lines changed: 2 additions & 1 deletion
@@ -428,7 +428,8 @@ def forward(
         # weights if weights are externally touched outside this module
         ctx.weight_object = weight
 
-        mark_not_offload(weight, weightmat, bias)
+        if cpu_offloading:
+            mark_not_offload(weight, weightmat, bias)
         # TODO(ksivamani): Check memory usage
         tensors_to_save, tensor_objects = prepare_for_saving(
             saved_inputmat,
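For context, `_check_if_offload` in cpu_offload.py skips tensors tagged with a `_TE_do_not_offload` attribute (visible in this commit's diff). A loose sketch of what such a marking helper could look like, purely illustrative since the actual `mark_not_offload` implementation is not part of this diff:

```python
def mark_not_offload_sketch(*tensors):
    """Illustrative only: tag tensors so _check_if_offload() leaves them on GPU."""
    for t in tensors:
        if t is not None:
            t._TE_do_not_offload = True  # attribute read via getattr(t, "_TE_do_not_offload", False)
```

Guarding the real call behind `if cpu_offloading:` means the weights and bias are only tagged when offloading is actually enabled.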

transformer_engine/pytorch/optimizers/fused_adam.py

Lines changed: 5 additions & 2 deletions
@@ -14,6 +14,7 @@
 from torch.distributed._tensor import DTensor
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
+from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
 from .multi_tensor_apply import multi_tensor_applier
 
 
@@ -372,10 +373,12 @@ def _initialize_state(
             store_param_remainders (bool): Store only trailing remainder bits.
         """
         dtype = self.name_to_dtype_map[state_name]
+        # Handle QuantizedTensor by dequantizing first
+        param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
         if store_param_remainders:
-            data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+            data = torch.zeros_like(param_for_empty, dtype=torch.int16)
         else:
-            data = torch.empty(param.shape, dtype=dtype, device=param.device)
+            data = torch.empty_like(param_for_empty, dtype=dtype)
         if zero_buffer:
             data.zero_()
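The state-initialization change can be read in isolation; a minimal sketch of the new logic, assuming `param` may be a QuantizedTensor (the default dtype and flags here are illustrative, not the optimizer's actual configuration):

```python
import torch
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor

def init_state_buffer(param, dtype=torch.float32, store_param_remainders=False, zero_buffer=True):
    # Dequantize QuantizedTensor params first so *_like can derive shape/device/layout
    # from a plain torch.Tensor instead of the quantized wrapper.
    param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
    if store_param_remainders:
        data = torch.zeros_like(param_for_empty, dtype=torch.int16)  # trailing remainder bits
    else:
        data = torch.empty_like(param_for_empty, dtype=dtype)
    if zero_buffer:
        data.zero_()
    return data
```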

transformer_engine/pytorch/quantized_tensor.py

Lines changed: 0 additions & 14 deletions
@@ -20,11 +20,6 @@
     _stride_from_shape,
 )
 
-_quantized_tensor_cpu_supported_ops = (
-    torch.ops.aten.empty_like.default,
-    torch.ops.aten.copy_.default,
-)
-
 
 class QuantizedTensorStorage:
     r"""Base class for all TensorStorage classes.
@@ -539,15 +534,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
 
-        def check_if_cpu(arg):
-            if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu":
-                assert (
-                    func in _quantized_tensor_cpu_supported_ops
-                ), f"QuantizedTensor on CPU does not support this operation: {func}"
-            return arg
-
-        args = tree_map(check_if_cpu, args)
-
         # Do not force the QuantizedTensor type on the returned tensor
         return torch._C._disabled_torch_function_impl(func, types, args, kwargs)