
Commit c120bb7

nvfp4 tensor: switch to TorchAOBaseTensor (#2788)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 4463b79 commit c120bb7

File tree

4 files changed: +24 −65 lines

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 2 additions & 2 deletions
@@ -916,8 +916,8 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert "_is_swizzled_scales" in ctx
-    assert ctx["_is_swizzled_scales"] == True
+    assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales"
+    assert ctx[3] == True
 
     # Test deserialization
     inner_tensors = {}

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 2 additions & 2 deletions
@@ -304,8 +304,8 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert "_is_swizzled_scales" in ctx
-    assert ctx["_is_swizzled_scales"] == True
+    assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales"
+    assert ctx[3] == True
 
     # Test deserialization
     inner_tensors = {}
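
Both test updates reflect the same interface change: TorchAOBaseTensor's generic __tensor_flatten__ returns the context as a positional list of attribute values, ordered by tensor_attribute_names, rather than a dict keyed by attribute name. A minimal sketch of the new lookup, assuming an already-constructed NVFP4Tensor `t` (the hard-coded index 3 in the tests comes from the position of "_is_swizzled_scales" in the list declared in the nvfp4_tensor.py diff below):

    tensor_list, ctx = t.__tensor_flatten__()
    # ctx is ordered like NVFP4Tensor.tensor_attribute_names:
    # ["_block_size", "_orig_dtype", "mm_config", "_is_swizzled_scales", "use_triton_kernel"]
    idx = NVFP4Tensor.tensor_attribute_names.index("_is_swizzled_scales")  # == 3
    assert ctx[idx] == t._is_swizzled_scales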

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 18 additions & 61 deletions
@@ -6,7 +6,7 @@
 
 import sys
 from enum import Enum
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Dict, Optional
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -24,7 +24,7 @@
     tensor_size_hp_to_fp4x2,
 )
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
-from torchao.utils import ceil_div, fill_defaults
+from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults
 
 E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
 
@@ -38,6 +38,7 @@ class NVFP4MMConfig(Enum):
     WEIGHT_ONLY = "weight_only"
 
 
+# TODO(future PR): move over to TorchAOBaseTensor's dispatch
 def implements(aten_ops):
     """Register aten ops to the NVFP4 op table"""
 
@@ -49,7 +50,7 @@ def decorator(func):
     return decorator
 
 
-class NVFP4Tensor(torch.Tensor):
+class NVFP4Tensor(TorchAOBaseTensor):
     """NVIDIA FP4 (NVFP4) Tensor subclass.
 
     This implements the NVIDIA variant of MX FP4 format, which uses a specific
@@ -59,20 +60,22 @@ class NVFP4Tensor(torch.Tensor):
         qdata: Packed FP4 data (2 values per byte)
         _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
         _per_tensor_scale: Optional global per-tensor scale in float32 format
-        _block_size: Block size for quantization (fixed at 16)
-        _orig_dtype: Original tensor dtype before quantization
-        _is_swizzled_scales: Whether scales are stored in swizzled (blocked) format
-        mm_config: Matrix multiplication configuration
+        _block_size (int): Block size for quantization (fixed at 16)
+        _orig_dtype (torch.dtype): Original tensor dtype before quantization
+        _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        mm_config (NVFP4MMConfig): Matrix multiplication configuration
+        use_triton_kernel (bool): Whether to use triton kernels
     """
 
-    qdata: torch.Tensor
-    _scale_e4m3: torch.Tensor
-    _per_tensor_scale: Optional[torch.Tensor]
-    _block_size: int
-    _orig_dtype: torch.dtype
-    _is_swizzled_scales: bool
-    mm_config: NVFP4MMConfig
-    use_triton_kernel: bool
+    tensor_data_names = ["qdata", "_scale_e4m3"]
+    optional_tensor_data_names = ["_per_tensor_scale"]
+    tensor_attribute_names = [
+        "_block_size",
+        "_orig_dtype",
+        "mm_config",
+        "_is_swizzled_scales",
+        "use_triton_kernel",
+    ]
 
     def __new__(
         cls,
@@ -173,52 +176,6 @@ def to_nvfp4(
             use_triton_kernel,
         )
 
-    def __tensor_flatten__(self):
-        ctx = {
-            "_block_size": self._block_size,
-            "_orig_dtype": self._orig_dtype,
-            "_is_swizzled_scales": self._is_swizzled_scales,
-            "mm_config": self.mm_config,
-            "use_triton_kernel": self.use_triton_kernel,
-        }
-        tensor_list = ["qdata", "_scale_e4m3"]
-        if self._per_tensor_scale is not None:
-            tensor_list.append("_per_tensor_scale")
-        return tensor_list, ctx
-
-    def _apply_fn_to_data(self, fn: Callable):
-        """Applies a fn to all tensor components stored on this class"""
-        tensor_names, ctx = self.__tensor_flatten__()
-        new_tensors = {}
-        for name in tensor_names:
-            new_tensors[name] = fn(getattr(self, name))
-        if "_per_tensor_scale" not in tensor_names:
-            new_tensors["_per_tensor_scale"] = None
-        return self.__class__.__tensor_unflatten__(
-            new_tensors,
-            ctx,
-            None,
-            None,
-        )
-
-    @staticmethod
-    def __tensor_unflatten__(
-        inner_tensors,
-        metadata,
-        outer_size,
-        outer_stride,
-    ):
-        return NVFP4Tensor(
-            inner_tensors["qdata"],
-            inner_tensors["_scale_e4m3"],
-            inner_tensors.get("_per_tensor_scale", None),
-            metadata["_block_size"],
-            metadata["_orig_dtype"],
-            metadata["mm_config"],
-            metadata.get("_is_swizzled_scales", False),
-            metadata.get("use_triton_kernel", False),
-        )
-
     # Do not force the NVFP4Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl
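
The deleted flatten/unflatten/apply boilerplate is now inherited from TorchAOBaseTensor, which derives it from the three class-level name lists. A simplified sketch of the inherited behavior, assuming it mirrors the __tensor_flatten__ fragment visible in the torchao/utils.py hunk below (the unflatten side is a paraphrase for illustration, not the exact torchao source):

    def __tensor_flatten__(self):
        # required tensor components, plus any optional ones that are present
        tensor_data_names = list(self.tensor_data_names)
        for name in getattr(self, "optional_tensor_data_names", []):
            if getattr(self, name) is not None:
                tensor_data_names.append(name)
        # ctx: positional list of the non-tensor attributes
        return tensor_data_names, [
            getattr(self, attr) for attr in self.tensor_attribute_names
        ]

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride):
        # rebuild with tensors in declared order (absent optionals become None),
        # then the attribute values in tensor_attribute_names order
        tensors = [inner_tensors[name] for name in cls.tensor_data_names]
        tensors += [
            inner_tensors.get(name)
            for name in getattr(cls, "optional_tensor_data_names", [])
        ]
        return cls(*tensors, *ctx)

Note that the declared tensor_attribute_names order (mm_config before _is_swizzled_scales) matches the NVFP4Tensor constructor signature, which is what lets the generic unflatten splat ctx directly into the constructor.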

torchao/utils.py

Lines changed: 2 additions & 0 deletions
@@ -810,6 +810,8 @@ def __tensor_flatten__(self):
                 if maybe_tensor is not None:
                     tensor_data_names.append(tensor_data_name)
 
+        # TODO(future PR): also return names of tensor attributes for easier
+        # debugging
         return tensor_data_names, [
            getattr(self, attr) for attr in self.tensor_attribute_names
        ]
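
For reference, flattening an NVFP4Tensor whose _per_tensor_scale is None would then return something like the following (illustrative values: bfloat16 as the original dtype is a made-up example, while block size 16 is fixed per the docstring above):

    tensor_names, ctx = t.__tensor_flatten__()
    # tensor_names == ["qdata", "_scale_e4m3"]   ("_per_tensor_scale" omitted: it is None)
    # ctx == [16, torch.bfloat16, mm_config, False, False]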
