Skip to content

Commit 4463b79

Browse files
authored
nvfp4 tensor: switch to using qdata (#2787)
* Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent 751d7f6 commit 4463b79

File tree

3 files changed

+36
-36
lines changed

3 files changed

+36
-36
lines changed

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -888,12 +888,12 @@ def test_nvfp4_swizzled_scales_view_semantics():
888888

889889
# Test that the sliced tensor shares storage with original for data
890890
# (Note: scales might not share storage due to swizzled layout complexity)
891-
assert sliced_tensor._data.data_ptr() == tensor._data.data_ptr()
891+
assert sliced_tensor.qdata.data_ptr() == tensor.qdata.data_ptr()
892892

893893
# Test full-width column slicing (should maintain views)
894894
full_width_slice = tensor[:, 0:K]
895895
assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr()
896-
assert full_width_slice._data.data_ptr() == tensor._data.data_ptr()
896+
assert full_width_slice.qdata.data_ptr() == tensor.qdata.data_ptr()
897897

898898

899899
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -1011,8 +1011,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
10111011
torch.testing.assert_close(
10121012
nvfp4_pt._scale_e4m3.flatten(), nvfp4_triton._scale_e4m3.flatten()
10131013
)
1014-
pt_unpacked = unpack_uint4(nvfp4_pt._data)
1015-
triton_unpacked = unpack_uint4(nvfp4_triton._data)
1014+
pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
1015+
triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
10161016
torch.testing.assert_close(
10171017
pt_unpacked,
10181018
triton_unpacked,

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,12 @@ def test_nvfp4_swizzled_scales_view_semantics():
276276

277277
# Test that the sliced tensor shares storage with original for data
278278
# (Note: scales might not share storage due to swizzled layout complexity)
279-
assert sliced_tensor._data.data_ptr() == tensor._data.data_ptr()
279+
assert sliced_tensor.qdata.data_ptr() == tensor.qdata.data_ptr()
280280

281281
# Test full-width column slicing (should maintain views)
282282
full_width_slice = tensor[:, 0:K]
283283
assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr()
284-
assert full_width_slice._data.data_ptr() == tensor._data.data_ptr()
284+
assert full_width_slice.qdata.data_ptr() == tensor.qdata.data_ptr()
285285

286286

287287
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -399,8 +399,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
399399
torch.testing.assert_close(
400400
nvfp4_pt._scale_e4m3.flatten(), nvfp4_triton._scale_e4m3.flatten()
401401
)
402-
pt_unpacked = unpack_uint4(nvfp4_pt._data)
403-
triton_unpacked = unpack_uint4(nvfp4_triton._data)
402+
pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
403+
triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
404404
torch.testing.assert_close(
405405
pt_unpacked,
406406
triton_unpacked,

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,18 @@ class NVFP4Tensor(torch.Tensor):
5656
quantization algorithm for FP4 data with UE4M3 scales.
5757
5858
Attributes:
59+
qdata: Packed FP4 data (2 values per byte)
5960
_scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
6061
_per_tensor_scale: Optional global per-tensor scale in float32 format
61-
_data: Packed FP4 data (2 values per byte)
6262
_block_size: Block size for quantization (fixed at 16)
6363
_orig_dtype: Original tensor dtype before quantization
6464
_is_swizzled_scales: Whether scales are stored in swizzled (blocked) format
6565
mm_config: Matrix multiplication configuration
6666
"""
6767

68+
qdata: torch.Tensor
6869
_scale_e4m3: torch.Tensor
6970
_per_tensor_scale: Optional[torch.Tensor]
70-
_data: torch.Tensor
7171
_block_size: int
7272
_orig_dtype: torch.dtype
7373
_is_swizzled_scales: bool
@@ -76,43 +76,43 @@ class NVFP4Tensor(torch.Tensor):
7676

7777
def __new__(
7878
cls,
79+
qdata,
7980
blockwise_scales,
8081
per_tensor_scale,
81-
data_bits,
8282
block_size,
8383
orig_dtype,
8484
mm_config=NVFP4MMConfig.DYNAMIC,
8585
is_swizzled_scales=False,
8686
use_triton_kernel=False,
8787
):
8888
# FP4 tensor size handling two paths, contiguous or not
89-
new_size = data_bits.size()
89+
new_size = qdata.size()
9090

9191
new_size = tensor_size_fp4x2_to_hp(
9292
new_size,
93-
data_bits.stride(0) > data_bits.stride(1),
93+
qdata.stride(0) > qdata.stride(1),
9494
)
9595

9696
self = torch.Tensor._make_wrapper_subclass(
9797
cls,
9898
new_size,
9999
dtype=orig_dtype,
100-
device=data_bits.device,
100+
device=qdata.device,
101101
requires_grad=False,
102102
)
103103

104104
self._scale_e4m3 = blockwise_scales
105105
self._is_swizzled_scales = is_swizzled_scales
106106
self._per_tensor_scale = per_tensor_scale
107-
self._data = data_bits
107+
self.qdata = qdata
108108
self._block_size = block_size
109109
self._orig_dtype = orig_dtype
110110
self.mm_config = mm_config
111111
self.use_triton_kernel = use_triton_kernel
112112
return self
113113

114114
def __repr__(self):
115-
return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self._data}, d_hp: {self.to_dtype(self._orig_dtype)}"
115+
return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
116116

117117
@classmethod
118118
def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -163,9 +163,9 @@ def to_nvfp4(
163163
).flatten()
164164

165165
return NVFP4Tensor(
166+
data_lp,
166167
blockwise_scales,
167168
per_tensor_scale,
168-
data_lp,
169169
block_size,
170170
data_hp.dtype,
171171
mm_config,
@@ -181,7 +181,7 @@ def __tensor_flatten__(self):
181181
"mm_config": self.mm_config,
182182
"use_triton_kernel": self.use_triton_kernel,
183183
}
184-
tensor_list = ["_scale_e4m3", "_data"]
184+
tensor_list = ["qdata", "_scale_e4m3"]
185185
if self._per_tensor_scale is not None:
186186
tensor_list.append("_per_tensor_scale")
187187
return tensor_list, ctx
@@ -209,9 +209,9 @@ def __tensor_unflatten__(
209209
outer_stride,
210210
):
211211
return NVFP4Tensor(
212+
inner_tensors["qdata"],
212213
inner_tensors["_scale_e4m3"],
213214
inner_tensors.get("_per_tensor_scale", None),
214-
inner_tensors["_data"],
215215
metadata["_block_size"],
216216
metadata["_orig_dtype"],
217217
metadata["mm_config"],
@@ -231,12 +231,12 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
231231
Returns:
232232
torch.Tensor: Dequantized tensor in the target dtype
233233
"""
234-
is_transposed = self._data.stride(0) < self._data.stride(1)
234+
is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
235235
if is_transposed:
236236
M, K = self.shape[1], self.shape[0]
237237
else:
238238
M, K = self.shape[0], self.shape[1]
239-
data = self._data.t() if is_transposed else self._data
239+
data = self.qdata.t() if is_transposed else self.qdata
240240
data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
241241
data_f32 = f4_unpacked_to_f32(data_unpacked)
242242

@@ -256,7 +256,7 @@ def get_hp_scales(self) -> torch.Tensor:
256256
Returns:
257257
torch.Tensor: Scales of the NVFP4Tensor
258258
"""
259-
is_transposed = self._data.stride(0) < self._data.stride(1)
259+
is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
260260
if is_transposed:
261261
M, K = self.shape[1], self.shape[0]
262262
else:
@@ -296,7 +296,7 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
296296
and self._is_swizzled_scales == src._is_swizzled_scales
297297
and self._scale_e4m3.shape == src._scale_e4m3.shape
298298
and per_tensor_scale_equal
299-
and self._data.shape == src._data.shape
299+
and self.qdata.shape == src.qdata.shape
300300
)
301301

302302

@@ -379,7 +379,7 @@ def nvfp4_slice(func, types, args, kwargs):
379379
if step != 1:
380380
raise ValueError("Only support aten.slice with step=1")
381381

382-
assert x._data.is_contiguous(), "Only support contiguous data for now"
382+
assert x.qdata.is_contiguous(), "Only support contiguous data for now"
383383

384384
M, K = x.shape[0], x.shape[1]
385385

@@ -422,7 +422,7 @@ def nvfp4_slice(func, types, args, kwargs):
422422
)
423423

424424
sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
425-
sliced_data = aten.slice.Tensor(x._data, 0, start, end, step)
425+
sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)
426426

427427
elif dim == 1:
428428
# Column slicing
@@ -485,7 +485,7 @@ def nvfp4_slice(func, types, args, kwargs):
485485
packed_start = None if start is None else start // 2
486486
packed_end = None if end is None else end // 2
487487
sliced_data = aten.slice.Tensor(
488-
x._data, dim, packed_start, packed_end, step
488+
x.qdata, dim, packed_start, packed_end, step
489489
)
490490

491491
else:
@@ -498,7 +498,7 @@ def nvfp4_slice(func, types, args, kwargs):
498498

499499
if dim == 0:
500500
sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step)
501-
sliced_data = aten.slice.Tensor(x._data, dim, start, end, step)
501+
sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step)
502502

503503
elif dim == 1:
504504
if start is not None:
@@ -518,7 +518,7 @@ def nvfp4_slice(func, types, args, kwargs):
518518
packed_start = None if start is None else start // 2
519519
packed_end = None if end is None else end // 2
520520
sliced_data = aten.slice.Tensor(
521-
x._data, dim, packed_start, packed_end, step
521+
x.qdata, dim, packed_start, packed_end, step
522522
)
523523

524524
start_block = 0 if start is None else start // x._block_size
@@ -531,9 +531,9 @@ def nvfp4_slice(func, types, args, kwargs):
531531

532532
# Create result tensor
533533
result = NVFP4Tensor(
534+
sliced_data,
534535
sliced_scale,
535536
x._per_tensor_scale,
536-
sliced_data,
537537
x._block_size,
538538
x._orig_dtype,
539539
x.mm_config,
@@ -549,9 +549,9 @@ def nvfp4_t(func, types, args, kwargs):
549549
# For now, only transpose(input, 0, 1) is supported.
550550
old = args[0]
551551
new = NVFP4Tensor(
552+
old.qdata.t(),
552553
old._scale_e4m3,
553554
old._per_tensor_scale,
554-
old._data.t(),
555555
old._block_size,
556556
old._orig_dtype,
557557
old.mm_config,
@@ -563,14 +563,14 @@ def nvfp4_t(func, types, args, kwargs):
563563

564564
@implements([aten.view.default])
565565
def nvfp4_view_op(func, types, args, kwargs):
566-
data = args[0]._data
566+
data = args[0].qdata
567567
new_size = args[1]
568568
new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous())
569569
new_data = func(data, new_size, *args[2:], **kwargs)
570570
return NVFP4Tensor(
571+
new_data,
571572
args[0]._scale_e4m3,
572573
args[0]._per_tensor_scale,
573-
new_data,
574574
args[0]._block_size,
575575
args[0]._orig_dtype,
576576
args[0].mm_config,
@@ -586,8 +586,8 @@ def _addmm_nvfp4_dispatch(
586586
Core implementation shared between nvfp4_mm, nvfp4_addmm, and nvfp4_linear.
587587
The only difference is whether bias is None or not.
588588
"""
589-
assert a._data.is_contiguous()
590-
assert b._data.t().is_contiguous()
589+
assert a.qdata.is_contiguous()
590+
assert b.qdata.t().is_contiguous()
591591
assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
592592
assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
593593

@@ -623,8 +623,8 @@ def _addmm_nvfp4_dispatch(
623623
# should_add_bias_separately = bias is not None
624624

625625
result = torch._scaled_mm(
626-
a._data.view(torch.float4_e2m1fn_x2),
627-
b._data.view(torch.float4_e2m1fn_x2),
626+
a.qdata.view(torch.float4_e2m1fn_x2),
627+
b.qdata.view(torch.float4_e2m1fn_x2),
628628
a_scale_blocked.view(torch.float8_e4m3fn),
629629
b_scale_blocked.view(torch.float8_e4m3fn),
630630
bias=None if should_add_bias_separately else bias,

0 commit comments

Comments
 (0)