@@ -56,18 +56,18 @@ class NVFP4Tensor(torch.Tensor):
     quantization algorithm for FP4 data with UE4M3 scales.
 
     Attributes:
+        qdata: Packed FP4 data (2 values per byte)
         _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
         _per_tensor_scale: Optional global per-tensor scale in float32 format
-        _data: Packed FP4 data (2 values per byte)
         _block_size: Block size for quantization (fixed at 16)
         _orig_dtype: Original tensor dtype before quantization
         _is_swizzled_scales: Whether scales are stored in swizzled (blocked) format
         mm_config: Matrix multiplication configuration
     """
 
+    qdata: torch.Tensor
     _scale_e4m3: torch.Tensor
     _per_tensor_scale: Optional[torch.Tensor]
-    _data: torch.Tensor
     _block_size: int
     _orig_dtype: torch.dtype
     _is_swizzled_scales: bool
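For context on the renamed attribute: qdata holds two FP4 (e2m1) code points per uint8 byte. A minimal sketch of that packing, assuming even-indexed elements go in the low nibble (the real nibble layout is whatever torchao's unpack_uint4 helper defines):

import torch

def pack_fp4_pairs(codes: torch.Tensor) -> torch.Tensor:
    # codes: uint8 tensor of 4-bit FP4 code points (0..15); last dim must be even.
    # Two codes share one byte: even index -> low nibble, odd index -> high nibble.
    lo, hi = codes[..., 0::2], codes[..., 1::2]
    return (hi << 4) | lo

def unpack_fp4_pairs(packed: torch.Tensor) -> torch.Tensor:
    # Inverse: split each byte back into its two 4-bit code points.
    lo, hi = packed & 0x0F, (packed >> 4) & 0x0F
    return torch.stack((lo, hi), dim=-1).flatten(-2)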
@@ -76,43 +76,43 @@ class NVFP4Tensor(torch.Tensor):
 
     def __new__(
         cls,
+        qdata,
         blockwise_scales,
         per_tensor_scale,
-        data_bits,
         block_size,
         orig_dtype,
         mm_config=NVFP4MMConfig.DYNAMIC,
         is_swizzled_scales=False,
         use_triton_kernel=False,
     ):
         # FP4 tensor size handling two paths, contiguous or not
-        new_size = data_bits.size()
+        new_size = qdata.size()
 
         new_size = tensor_size_fp4x2_to_hp(
             new_size,
-            data_bits.stride(0) > data_bits.stride(1),
+            qdata.stride(0) > qdata.stride(1),
         )
 
         self = torch.Tensor._make_wrapper_subclass(
             cls,
             new_size,
             dtype=orig_dtype,
-            device=data_bits.device,
+            device=qdata.device,
             requires_grad=False,
         )
 
         self._scale_e4m3 = blockwise_scales
         self._is_swizzled_scales = is_swizzled_scales
         self._per_tensor_scale = per_tensor_scale
-        self._data = data_bits
+        self.qdata = qdata
         self._block_size = block_size
         self._orig_dtype = orig_dtype
         self.mm_config = mm_config
         self.use_triton_kernel = use_triton_kernel
         return self
 
     def __repr__(self):
-        return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self._data}, d_hp: {self.to_dtype(self._orig_dtype)}"
+        return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -163,9 +163,9 @@ def to_nvfp4(
     ).flatten()
 
     return NVFP4Tensor(
+        data_lp,
         blockwise_scales,
         per_tensor_scale,
-        data_lp,
         block_size,
         data_hp.dtype,
         mm_config,
@@ -181,7 +181,7 @@ def __tensor_flatten__(self):
             "mm_config": self.mm_config,
             "use_triton_kernel": self.use_triton_kernel,
         }
-        tensor_list = ["_scale_e4m3", "_data"]
+        tensor_list = ["qdata", "_scale_e4m3"]
         if self._per_tensor_scale is not None:
             tensor_list.append("_per_tensor_scale")
         return tensor_list, ctx
@@ -209,9 +209,9 @@ def __tensor_unflatten__(
         outer_stride,
     ):
         return NVFP4Tensor(
+            inner_tensors["qdata"],
             inner_tensors["_scale_e4m3"],
             inner_tensors.get("_per_tensor_scale", None),
-            inner_tensors["_data"],
             metadata["_block_size"],
             metadata["_orig_dtype"],
             metadata["mm_config"],
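These two hooks are the subclass (de)composition contract used by torch.compile and FSDP: __tensor_flatten__ names the inner tensors plus metadata, __tensor_unflatten__ rebuilds the wrapper, and the positional order in the constructor call must match the new qdata-first signature. A rough sketch of the round trip (hand-calling the hooks for illustration only; normally PyTorch internals invoke them):

# t is an existing NVFP4Tensor
tensor_names, ctx = t.__tensor_flatten__()  # ["qdata", "_scale_e4m3", ...], metadata dict
inner = {name: getattr(t, name) for name in tensor_names}
t2 = NVFP4Tensor.__tensor_unflatten__(inner, ctx, t.size(), t.stride())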
@@ -231,12 +231,12 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         Returns:
             torch.Tensor: Dequantized tensor in the target dtype
         """
-        is_transposed = self._data.stride(0) < self._data.stride(1)
+        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
         else:
             M, K = self.shape[0], self.shape[1]
-        data = self._data.t() if is_transposed else self._data
+        data = self.qdata.t() if is_transposed else self.qdata
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
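What these lines feed into, conceptually: each decoded FP4 value is multiplied by its 16-element block's e4m3 scale, then by the optional global scale. A hypothetical pure-PyTorch sketch (dequantize_blockwise is an invented name for illustration; the real method continues via get_hp_scales):

import torch

def dequantize_blockwise(data_f32, scale_e4m3, per_tensor_scale, block_size=16):
    # data_f32: [M, K] decoded FP4 values; scale_e4m3: [M, K // block_size].
    scales = scale_e4m3.to(torch.float32).repeat_interleave(block_size, dim=1)
    out = data_f32 * scales            # apply each block's scale
    if per_tensor_scale is not None:
        out = out * per_tensor_scale   # optional global per-tensor scale
    return out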
@@ -256,7 +256,7 @@ def get_hp_scales(self) -> torch.Tensor:
         Returns:
             torch.Tensor: Scales of the NVFP4Tensor
         """
-        is_transposed = self._data.stride(0) < self._data.stride(1)
+        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
         else:
@@ -296,7 +296,7 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             and self._is_swizzled_scales == src._is_swizzled_scales
             and self._scale_e4m3.shape == src._scale_e4m3.shape
             and per_tensor_scale_equal
-            and self._data.shape == src._data.shape
+            and self.qdata.shape == src.qdata.shape
         )
 
 
@@ -379,7 +379,7 @@ def nvfp4_slice(func, types, args, kwargs):
     if step != 1:
         raise ValueError("Only support aten.slice with step=1")
 
-    assert x._data.is_contiguous(), "Only support contiguous data for now"
+    assert x.qdata.is_contiguous(), "Only support contiguous data for now"
 
     M, K = x.shape[0], x.shape[1]
 
@@ -422,7 +422,7 @@ def nvfp4_slice(func, types, args, kwargs):
             )
 
             sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
-            sliced_data = aten.slice.Tensor(x._data, 0, start, end, step)
+            sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)
 
         elif dim == 1:
             # Column slicing
@@ -485,7 +485,7 @@ def nvfp4_slice(func, types, args, kwargs):
             packed_start = None if start is None else start // 2
             packed_end = None if end is None else end // 2
             sliced_data = aten.slice.Tensor(
-                x._data, dim, packed_start, packed_end, step
+                x.qdata, dim, packed_start, packed_end, step
             )
 
     else:
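The start // 2 and end // 2 above are the packed-byte index math: with two FP4 elements per byte, element column c lives in byte column c // 2, so an element-space slice [start, end) becomes [start // 2, end // 2) on qdata (which is why the surrounding code requires even-aligned bounds). A tiny check of the arithmetic:

# Element columns 4..12 span 8 elements, i.e. 4 packed bytes.
start, end = 4, 12
packed_start, packed_end = start // 2, end // 2  # byte columns 2..6
assert packed_end - packed_start == (end - start) // 2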
@@ -498,7 +498,7 @@ def nvfp4_slice(func, types, args, kwargs):
 
         if dim == 0:
             sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step)
-            sliced_data = aten.slice.Tensor(x._data, dim, start, end, step)
+            sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step)
 
         elif dim == 1:
             if start is not None:
@@ -518,7 +518,7 @@ def nvfp4_slice(func, types, args, kwargs):
             packed_start = None if start is None else start // 2
             packed_end = None if end is None else end // 2
             sliced_data = aten.slice.Tensor(
-                x._data, dim, packed_start, packed_end, step
+                x.qdata, dim, packed_start, packed_end, step
             )
 
             start_block = 0 if start is None else start // x._block_size
@@ -531,9 +531,9 @@ def nvfp4_slice(func, types, args, kwargs):
 
     # Create result tensor
     result = NVFP4Tensor(
+        sliced_data,
         sliced_scale,
         x._per_tensor_scale,
-        sliced_data,
         x._block_size,
         x._orig_dtype,
         x.mm_config,
@@ -549,9 +549,9 @@ def nvfp4_t(func, types, args, kwargs):
     # For now, only transpose(input, 0, 1) is supported.
     old = args[0]
     new = NVFP4Tensor(
+        old.qdata.t(),
         old._scale_e4m3,
         old._per_tensor_scale,
-        old._data.t(),
         old._block_size,
         old._orig_dtype,
         old.mm_config,
@@ -563,14 +563,14 @@ def nvfp4_t(func, types, args, kwargs):
 
 @implements([aten.view.default])
 def nvfp4_view_op(func, types, args, kwargs):
-    data = args[0]._data
+    data = args[0].qdata
     new_size = args[1]
     new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous())
     new_data = func(data, new_size, *args[2:], **kwargs)
     return NVFP4Tensor(
+        new_data,
         args[0]._scale_e4m3,
         args[0]._per_tensor_scale,
-        new_data,
         args[0]._block_size,
         args[0]._orig_dtype,
         args[0].mm_config,
@@ -586,8 +586,8 @@ def _addmm_nvfp4_dispatch(
     Core implementation shared between nvfp4_mm, nvfp4_addmm, and nvfp4_linear.
     The only difference is whether bias is None or not.
     """
-    assert a._data.is_contiguous()
-    assert b._data.t().is_contiguous()
+    assert a.qdata.is_contiguous()
+    assert b.qdata.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
 
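The two asserts above pin down the operand layouts the scaled-mm kernel expects: first operand row-major, second column-major, exactly as a.qdata.is_contiguous() and b.qdata.t().is_contiguous() state. A standalone check of why the second expression means "column-major":

import torch

b = torch.randn(8, 4).t()      # a transposed view, shape [4, 8]
assert not b.is_contiguous()   # not row-major...
assert b.t().is_contiguous()   # ...but its transpose is, so b is column-major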
@@ -623,8 +623,8 @@ def _addmm_nvfp4_dispatch(
     # should_add_bias_separately = bias is not None
 
     result = torch._scaled_mm(
-        a._data.view(torch.float4_e2m1fn_x2),
-        b._data.view(torch.float4_e2m1fn_x2),
+        a.qdata.view(torch.float4_e2m1fn_x2),
+        b.qdata.view(torch.float4_e2m1fn_x2),
         a_scale_blocked.view(torch.float8_e4m3fn),
         b_scale_blocked.view(torch.float8_e4m3fn),
         bias=None if should_add_bias_separately else bias,
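A hedged usage sketch of the renamed attribute from the caller's side (assumes to_nvfp4 from this diff is importable at module level and a CUDA build with FP4 support; the exact import path is an assumption):

import torch
# from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor, to_nvfp4  # path may differ

x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
x_nvfp4 = to_nvfp4(x)                      # packed FP4 payload + e4m3 block scales
print(x_nvfp4.qdata.shape)                 # torch.Size([128, 32]): two values per byte
x_hp = x_nvfp4.to_dtype(torch.bfloat16)    # dequantize back to the original dtype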