5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import sys
8
+ from dataclasses import dataclass
8
9
from enum import Enum
9
10
from typing import Any , Dict , Optional
10
11
24
25
tensor_size_hp_to_fp4x2 ,
25
26
)
26
27
from torchao .prototype .mx_formats .utils import from_blocked , to_blocked
28
+ from torchao .quantization .quantize_ .common import (
29
+ QuantizeTensorKwargs ,
30
+ )
27
31
from torchao .utils import TorchAOBaseTensor , ceil_div , fill_defaults
28
32
29
33
E4M3_EPS = torch .finfo (torch .float8_e4m3fn ).tiny
@@ -38,6 +42,13 @@ class NVFP4MMConfig(Enum):
38
42
WEIGHT_ONLY = "weight_only"
39
43
40
44
45
# Added by this change: a dataclass bundling the options used to quantize an
# activation to NVFP4. The matmul handlers visible later in this diff
# (nvfp4_linear / nvfp4_mm / nvfp4_addmm) read it from
# `weight_tensor.act_quant_kwargs`; when that attribute is None they take the
# weight-only path and dequantize the weight instead.
+ @dataclass
46
+ class QuantizeTensorToNVFP4Kwargs (QuantizeTensorKwargs ):
47
# Forwarded as `block_size=` to NVFP4Tensor.to_nvfp4, whose docstring states
# that the block size must be 16.
+ block_size : int = 16
48
# Forwarded as `is_swizzled_scales=` to NVFP4Tensor.to_nvfp4 (store scales in
# swizzled format). NOTE(review): the call sites removed by this diff
# hard-coded is_swizzled_scales=True, while the default here is False, so
# callers presumably have to opt in explicitly -- confirm.
+ is_swizzled_scales : bool = False
49
# Forwarded as `use_triton_kernel=` to NVFP4Tensor.to_nvfp4 (use the Triton
# kernel for quantization). The removed call sites took this value from
# `weight_tensor.use_triton_kernel` instead.
+ use_triton_kernel : bool = False
50
+
51
+
41
52
# TODO(future PR): move over to TorchAOBaseTensor's dispatch
42
53
def implements (aten_ops ):
43
54
"""Register aten ops to the NVFP4 op table"""
@@ -60,33 +71,34 @@ class NVFP4Tensor(TorchAOBaseTensor):
60
71
qdata: Packed FP4 data (2 values per byte)
61
72
_scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
62
73
_per_tensor_scale: Optional global per-tensor scale in float32 format
74
+ _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
63
75
_block_size (int): Block size for quantization (fixed at 16)
64
76
_orig_dtype (torch.dtype): Original tensor dtype before quantization
65
77
_is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
66
- mm_config (NVFP4MMConfig): Matrix multiplication configuration
67
78
use_triton_kernel (bool): Whether to use triton kernels
68
79
"""
69
80
70
81
tensor_data_names = ["qdata" , "_scale_e4m3" ]
71
- optional_tensor_data_names = ["_per_tensor_scale" ]
82
+ optional_tensor_data_names = ["_per_tensor_scale" , "_act_per_tensor_scale" ]
72
83
tensor_attribute_names = [
73
84
"_block_size" ,
74
85
"_orig_dtype" ,
75
- "mm_config" ,
76
86
"_is_swizzled_scales" ,
77
87
"use_triton_kernel" ,
88
+ "act_quant_kwargs" ,
78
89
]
79
90
80
91
def __new__ (
81
92
cls ,
82
93
qdata ,
83
94
blockwise_scales ,
84
95
per_tensor_scale ,
96
+ act_per_tensor_scale ,
85
97
block_size ,
86
98
orig_dtype ,
87
- mm_config = NVFP4MMConfig .DYNAMIC ,
88
99
is_swizzled_scales = False ,
89
100
use_triton_kernel = False ,
101
+ act_quant_kwargs = None ,
90
102
):
91
103
# FP4 tensor size handling two paths, contiguous or not
92
104
new_size = qdata .size ()
@@ -107,11 +119,12 @@ def __new__(
107
119
self ._scale_e4m3 = blockwise_scales
108
120
self ._is_swizzled_scales = is_swizzled_scales
109
121
self ._per_tensor_scale = per_tensor_scale
122
+ self ._act_per_tensor_scale = act_per_tensor_scale
110
123
self .qdata = qdata
111
124
self ._block_size = block_size
112
125
self ._orig_dtype = orig_dtype
113
- self .mm_config = mm_config
114
126
self .use_triton_kernel = use_triton_kernel
127
+ self .act_quant_kwargs = act_quant_kwargs
115
128
return self
116
129
117
130
def __repr__ (self ):
@@ -130,9 +143,10 @@ def to_nvfp4(
130
143
data_hp : torch .Tensor ,
131
144
block_size : int = 16 ,
132
145
per_tensor_scale : Optional [torch .Tensor ] = None ,
133
- mm_config : NVFP4MMConfig = NVFP4MMConfig . DYNAMIC ,
146
+ act_per_tensor_scale : Optional [ torch . Tensor ] = None ,
134
147
is_swizzled_scales : bool = False ,
135
148
use_triton_kernel : bool = False ,
149
+ act_quant_kwargs : Optional [QuantizeTensorToNVFP4Kwargs ] = None ,
136
150
):
137
151
"""Convert high precision tensor to NVFP4 format.
138
152
@@ -141,9 +155,11 @@ def to_nvfp4(
141
155
block_size: Block size for quantization (must be 16)
142
156
per_tensor_scale: Optional pre-computed absolute maximum for calibration.
143
157
If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
144
- mm_config: Matrix multiplication configuration
158
+ act_per_tensor_scale: Optional pre-computed absolute maximum for calibration for activation
159
+ If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
145
160
is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
146
161
use_triton_kernel: If True, use Triton kernel for quantization
162
+ act_quant_kwargs: If specified, config for quantizing the activation
147
163
148
164
Returns:
149
165
NVFP4Tensor: Quantized tensor in NVFP4 format
@@ -169,11 +185,12 @@ def to_nvfp4(
169
185
data_lp ,
170
186
blockwise_scales ,
171
187
per_tensor_scale ,
188
+ act_per_tensor_scale ,
172
189
block_size ,
173
190
data_hp .dtype ,
174
- mm_config ,
175
191
is_swizzled_scales ,
176
192
use_triton_kernel ,
193
+ act_quant_kwargs ,
177
194
)
178
195
179
196
# Do not force the NVFP4Tensor type on the returned tensor
@@ -244,6 +261,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
244
261
per_tensor_scale_equal = (
245
262
self ._per_tensor_scale is None and src ._per_tensor_scale is None
246
263
) or (self ._per_tensor_scale .shape == src ._per_tensor_scale .shape )
264
+ act_per_tensor_scale_equal = (
265
+ self ._act_per_tensor_scale is None and src ._act_per_tensor_scale is None
266
+ ) or (self ._act_per_tensor_scale .shape == src ._act_per_tensor_scale .shape )
247
267
248
268
return (
249
269
isinstance (self , NVFP4Tensor )
@@ -253,7 +273,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
253
273
and self ._is_swizzled_scales == src ._is_swizzled_scales
254
274
and self ._scale_e4m3 .shape == src ._scale_e4m3 .shape
255
275
and per_tensor_scale_equal
276
+ and act_per_tensor_scale_equal
256
277
and self .qdata .shape == src .qdata .shape
278
+ and self .act_quant_kwargs == src .act_quant_kwargs
257
279
)
258
280
259
281
@@ -290,12 +312,13 @@ def nvfp4_to_copy(func, types, args, kwargs):
290
312
res = NVFP4Tensor (
291
313
tensor ._scale_e4m3 ,
292
314
tensor ._per_tensor_scale ,
315
+ tensor ._act_per_tensor_scale ,
293
316
tensor ._data ,
294
317
tensor ._block_size ,
295
318
dtype ,
296
- tensor .mm_config ,
297
319
tensor ._is_swizzled_scales ,
298
320
tensor .use_triton_kernel ,
321
+ tensor .act_quant_kwargs ,
299
322
)
300
323
return res
301
324
@@ -491,11 +514,12 @@ def nvfp4_slice(func, types, args, kwargs):
491
514
sliced_data ,
492
515
sliced_scale ,
493
516
x ._per_tensor_scale ,
517
+ x ._act_per_tensor_scale ,
494
518
x ._block_size ,
495
519
x ._orig_dtype ,
496
- x .mm_config ,
497
520
x ._is_swizzled_scales ,
498
521
x .use_triton_kernel ,
522
+ x .act_quant_kwargs ,
499
523
)
500
524
501
525
return return_and_correct_aliasing (func , args , kwargs , result )
@@ -509,11 +533,12 @@ def nvfp4_t(func, types, args, kwargs):
509
533
old .qdata .t (),
510
534
old ._scale_e4m3 ,
511
535
old ._per_tensor_scale ,
536
+ old ._act_per_tensor_scale ,
512
537
old ._block_size ,
513
538
old ._orig_dtype ,
514
- old .mm_config ,
515
539
old ._is_swizzled_scales ,
516
540
old .use_triton_kernel ,
541
+ old .act_quant_kwargs ,
517
542
)
518
543
return new
519
544
@@ -528,11 +553,12 @@ def nvfp4_view_op(func, types, args, kwargs):
528
553
new_data ,
529
554
args [0 ]._scale_e4m3 ,
530
555
args [0 ]._per_tensor_scale ,
556
+ args [0 ]._act_per_tensor_scale ,
531
557
args [0 ]._block_size ,
532
558
args [0 ]._orig_dtype ,
533
- args [0 ].mm_config ,
534
559
args [0 ]._is_swizzled_scales ,
535
560
args [0 ].use_triton_kernel ,
561
+ args [0 ].act_quant_kwargs ,
536
562
)
537
563
538
564
@@ -610,17 +636,19 @@ def nvfp4_linear(func, types, args, kwargs):
610
636
if not isinstance (weight_tensor , NVFP4Tensor ):
611
637
raise NotImplementedError ("NVFP4Tensor: weight must be NVFP4Tensor" )
612
638
613
- config = weight_tensor .mm_config
614
-
615
- if config == NVFP4MMConfig .WEIGHT_ONLY :
639
+ if weight_tensor .act_quant_kwargs is None :
640
+ # weight_only quant
616
641
weight_dequant = weight_tensor .to_dtype (weight_tensor ._orig_dtype )
617
642
return torch .nn .functional .linear (input_tensor , weight_dequant , bias )
618
643
else :
644
+ # dynamic quant
645
+ k = weight_tensor .act_quant_kwargs
619
646
input_tensor = NVFP4Tensor .to_nvfp4 (
620
647
input_tensor ,
621
- mm_config = config ,
622
- is_swizzled_scales = True ,
623
- use_triton_kernel = weight_tensor .use_triton_kernel ,
648
+ block_size = k .block_size ,
649
+ per_tensor_scale = weight_tensor ._act_per_tensor_scale ,
650
+ is_swizzled_scales = k .is_swizzled_scales ,
651
+ use_triton_kernel = k .use_triton_kernel ,
624
652
)
625
653
return _addmm_nvfp4_dispatch (input_tensor , weight_tensor .t (), func , bias = bias )
626
654
@@ -632,9 +660,7 @@ def nvfp4_mm(func, types, args, kwargs):
632
660
if not isinstance (weight_tensor , NVFP4Tensor ):
633
661
raise NotImplementedError ("NVFP4Tensor: weight must be NVFP4Tensor" )
634
662
635
- config = weight_tensor .mm_config
636
-
637
- if config == NVFP4MMConfig .WEIGHT_ONLY :
663
+ if weight_tensor .act_quant_kwargs is None :
638
664
weight_dequant = weight_tensor .to_dtype (weight_tensor ._orig_dtype )
639
665
if isinstance (input_tensor , NVFP4Tensor ):
640
666
input_dequant = input_tensor .to_dtype (input_tensor ._orig_dtype )
@@ -643,11 +669,13 @@ def nvfp4_mm(func, types, args, kwargs):
643
669
return func (input_tensor , weight_dequant )
644
670
else :
645
671
if not isinstance (input_tensor , NVFP4Tensor ):
672
+ k = weight_tensor .act_quant_kwargs
646
673
input_tensor = NVFP4Tensor .to_nvfp4 (
647
674
input_tensor ,
648
- mm_config = config ,
649
- is_swizzled_scales = True ,
650
- use_triton_kernel = weight_tensor .use_triton_kernel ,
675
+ block_size = k .block_size ,
676
+ per_tensor_scale = weight_tensor ._act_per_tensor_scale ,
677
+ is_swizzled_scales = k .is_swizzled_scales ,
678
+ use_triton_kernel = k .use_triton_kernel ,
651
679
)
652
680
return _addmm_nvfp4_dispatch (input_tensor , weight_tensor , func )
653
681
@@ -659,9 +687,7 @@ def nvfp4_addmm(func, types, args, kwargs):
659
687
if not isinstance (weight_tensor , NVFP4Tensor ):
660
688
raise NotImplementedError ("NVFP4Tensor: weight must be NVFP4Tensor" )
661
689
662
- config = weight_tensor .mm_config
663
-
664
- if config == NVFP4MMConfig .WEIGHT_ONLY :
690
+ if weight_tensor .act_quant_kwargs is None :
665
691
weight_dequant = weight_tensor .to_dtype (weight_tensor ._orig_dtype )
666
692
if isinstance (input_tensor , NVFP4Tensor ):
667
693
input_dequant = input_tensor .to_dtype (input_tensor ._orig_dtype )
@@ -670,11 +696,13 @@ def nvfp4_addmm(func, types, args, kwargs):
670
696
return torch .addmm (bias , input_tensor , weight_dequant )
671
697
else :
672
698
if not isinstance (input_tensor , NVFP4Tensor ):
699
+ k = weight_tensor .act_quant_kwargs
673
700
input_tensor = NVFP4Tensor .to_nvfp4 (
674
701
input_tensor ,
675
- mm_config = config ,
676
- is_swizzled_scales = True ,
677
- use_triton_kernel = weight_tensor .use_triton_kernel ,
702
+ block_size = k .block_size ,
703
+ per_tensor_scale = weight_tensor ._act_per_tensor_scale ,
704
+ is_swizzled_scales = k .is_swizzled_scales ,
705
+ use_triton_kernel = k .use_triton_kernel ,
678
706
)
679
707
return _addmm_nvfp4_dispatch (input_tensor , weight_tensor , func , bias = bias )
680
708
0 commit comments