Commit 5c0d6a3

nvfp4 tensor: refactor weight-only vs dynamic quant (#2790)
1 parent c120bb7 commit 5c0d6a3

File tree: 4 files changed, +79 −38 lines changed

test/prototype/mx_formats/test_mx_tensor.py
Lines changed: 2 additions & 2 deletions

@@ -916,8 +916,8 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales"
-    assert ctx[3] == True
+    assert NVFP4Tensor.tensor_attribute_names[2] == "_is_swizzled_scales"
+    assert ctx[2] == True
 
     # Test deserialization
     inner_tensors = {}
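
Note: the index changes from 3 to 2 because this commit drops "mm_config" from NVFP4Tensor.tensor_attribute_names (see the nvfp4_tensor.py diff below), which shifts "_is_swizzled_scales" down by one slot. The new ordering, reproduced from that diff:

    tensor_attribute_names = [
        "_block_size",          # index 0
        "_orig_dtype",          # index 1
        "_is_swizzled_scales",  # index 2
        "use_triton_kernel",    # index 3
        "act_quant_kwargs",     # index 4
    ]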

test/prototype/mx_formats/test_nvfp4_tensor.py
Lines changed: 9 additions & 4 deletions

@@ -15,6 +15,9 @@
 from torchao.prototype.mx_formats.inference_workflow import (
     NVFP4MMConfig,
 )
+from torchao.prototype.mx_formats.nvfp4_tensor import (
+    QuantizeTensorToNVFP4Kwargs,
+)
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -304,8 +307,8 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales"
-    assert ctx[3] == True
+    assert NVFP4Tensor.tensor_attribute_names[2] == "_is_swizzled_scales"
+    assert ctx[2] == True
 
     # Test deserialization
     inner_tensors = {}
@@ -491,19 +494,21 @@ def test_nvfp4_matmul_with_amax(
 
     a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
     b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
+    act_quant_kwargs = None
+    if mm_config == NVFP4MMConfig.DYNAMIC:
+        act_quant_kwargs = QuantizeTensorToNVFP4Kwargs()
     A_nvfp4 = NVFP4Tensor.to_nvfp4(
         A,
         per_tensor_scale=a_scale,
-        mm_config=mm_config,
         is_swizzled_scales=True,
         use_triton_kernel=use_triton_kernel,
     )
     B_nvfp4 = NVFP4Tensor.to_nvfp4(
         B,
         per_tensor_scale=b_scale,
-        mm_config=mm_config,
         is_swizzled_scales=True,
         use_triton_kernel=use_triton_kernel,
+        act_quant_kwargs=act_quant_kwargs,
     )
 
     func = torch.compile(F.linear, fullgraph=True) if compile else F.linear
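
For context, a minimal sketch of the calling convention this test now exercises, based on the to_nvfp4 signature in the nvfp4_tensor.py diff below (the shapes, dtype, and device are illustrative only):

    import torch
    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4Tensor,
        QuantizeTensorToNVFP4Kwargs,
    )

    weight = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

    # mm_config is gone from to_nvfp4; dynamic activation quantization is now
    # requested by attaching act_quant_kwargs to the weight tensor, and
    # act_quant_kwargs=None (the default) means weight-only quantization.
    weight_nvfp4 = NVFP4Tensor.to_nvfp4(
        weight,
        is_swizzled_scales=True,
        act_quant_kwargs=QuantizeTensorToNVFP4Kwargs(),
    )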

torchao/prototype/mx_formats/inference_workflow.py
Lines changed: 10 additions & 2 deletions

@@ -19,7 +19,11 @@
     _validate_gemm_kernel_choice,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig, NVFP4Tensor
+from torchao.prototype.mx_formats.nvfp4_tensor import (
+    NVFP4MMConfig,
+    NVFP4Tensor,
+    QuantizeTensorToNVFP4Kwargs,
+)
 from torchao.quantization.quant_api import to_linear_activation_quantized
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
@@ -199,11 +203,15 @@ def _nvfp4_inference_linear_transform(
             "Please use bfloat16 or float16 weights, or remove the bias from the linear layer."
         )
 
+    act_quant_kwargs = None
+    if config.mm_config == NVFP4MMConfig.DYNAMIC:
+        act_quant_kwargs = QuantizeTensorToNVFP4Kwargs()
+
     quantized_weight = NVFP4Tensor.to_nvfp4(
         weight,
-        mm_config=config.mm_config,
         is_swizzled_scales=True,
         use_triton_kernel=False,  # Always use traditional construction for weights
+        act_quant_kwargs=act_quant_kwargs,
    )
    # Set triton preference after construction
    quantized_weight.use_triton_kernel = config.use_triton_kernel
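
NVFP4MMConfig survives at the workflow level and is translated into the new per-tensor state here. A hedged summary of the mapping, where _to_act_quant_kwargs is a hypothetical helper written only for illustration:

    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4MMConfig,
        QuantizeTensorToNVFP4Kwargs,
    )

    def _to_act_quant_kwargs(mm_config):
        # WEIGHT_ONLY -> None: dispatch dequantizes the weight and runs the
        # matmul in the original high precision.
        # DYNAMIC -> kwargs object: dispatch quantizes the activation to
        # NVFP4 on the fly using these settings.
        if mm_config == NVFP4MMConfig.DYNAMIC:
            return QuantizeTensorToNVFP4Kwargs()
        return None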

torchao/prototype/mx_formats/nvfp4_tensor.py
Lines changed: 58 additions & 30 deletions

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import sys
+from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, Optional
 
@@ -24,6 +25,9 @@
     tensor_size_hp_to_fp4x2,
 )
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+from torchao.quantization.quantize_.common import (
+    QuantizeTensorKwargs,
+)
 from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults
 
 E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
@@ -38,6 +42,13 @@ class NVFP4MMConfig(Enum):
     WEIGHT_ONLY = "weight_only"
 
 
+@dataclass
+class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs):
+    block_size: int = 16
+    is_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
+
+
 # TODO(future PR): move over to TorchAOBaseTensor's dispatch
 def implements(aten_ops):
     """Register aten ops to the NVFP4 op table"""
@@ -60,33 +71,34 @@ class NVFP4Tensor(TorchAOBaseTensor):
         qdata: Packed FP4 data (2 values per byte)
         _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
         _per_tensor_scale: Optional global per-tensor scale in float32 format
+        _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
         _block_size (int): Block size for quantization (fixed at 16)
         _orig_dtype (torch.dtype): Original tensor dtype before quantization
         _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
-        mm_config (NVFP4MMConfig): Matrix multiplication configuration
         use_triton_kernel (bool): Whether to use triton kernels
     """
 
     tensor_data_names = ["qdata", "_scale_e4m3"]
-    optional_tensor_data_names = ["_per_tensor_scale"]
+    optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
     tensor_attribute_names = [
         "_block_size",
         "_orig_dtype",
-        "mm_config",
         "_is_swizzled_scales",
         "use_triton_kernel",
+        "act_quant_kwargs",
     ]
 
     def __new__(
         cls,
         qdata,
         blockwise_scales,
         per_tensor_scale,
+        act_per_tensor_scale,
         block_size,
         orig_dtype,
-        mm_config=NVFP4MMConfig.DYNAMIC,
         is_swizzled_scales=False,
         use_triton_kernel=False,
+        act_quant_kwargs=None,
     ):
         # FP4 tensor size handling two paths, contiguous or not
         new_size = qdata.size()
@@ -107,11 +119,12 @@ def __new__(
         self._scale_e4m3 = blockwise_scales
         self._is_swizzled_scales = is_swizzled_scales
         self._per_tensor_scale = per_tensor_scale
+        self._act_per_tensor_scale = act_per_tensor_scale
         self.qdata = qdata
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self.mm_config = mm_config
         self.use_triton_kernel = use_triton_kernel
+        self.act_quant_kwargs = act_quant_kwargs
         return self
 
     def __repr__(self):
@@ -130,9 +143,10 @@ def to_nvfp4(
         data_hp: torch.Tensor,
         block_size: int = 16,
         per_tensor_scale: Optional[torch.Tensor] = None,
-        mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC,
+        act_per_tensor_scale: Optional[torch.Tensor] = None,
         is_swizzled_scales: bool = False,
         use_triton_kernel: bool = False,
+        act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
     ):
         """Convert high precision tensor to NVFP4 format.
 
@@ -141,9 +155,11 @@ def to_nvfp4(
             block_size: Block size for quantization (must be 16)
             per_tensor_scale: Optional pre-computed absolute maximum for calibration.
                 If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
-            mm_config: Matrix multiplication configuration
+            act_per_tensor_scale: Optional pre-computed absolute maximum for calibration for activation.
+                If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
             is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
             use_triton_kernel: If True, use Triton kernel for quantization
+            act_quant_kwargs: If specified, config for quantizing the activation
 
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
@@ -169,11 +185,12 @@ def to_nvfp4(
             data_lp,
             blockwise_scales,
             per_tensor_scale,
+            act_per_tensor_scale,
             block_size,
             data_hp.dtype,
-            mm_config,
             is_swizzled_scales,
             use_triton_kernel,
+            act_quant_kwargs,
         )
 
     # Do not force the NVFP4Tensor type on the returned tensor
@@ -244,6 +261,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
         per_tensor_scale_equal = (
             self._per_tensor_scale is None and src._per_tensor_scale is None
         ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape)
+        act_per_tensor_scale_equal = (
+            self._act_per_tensor_scale is None and src._act_per_tensor_scale is None
+        ) or (self._act_per_tensor_scale.shape == src._act_per_tensor_scale.shape)
 
         return (
             isinstance(self, NVFP4Tensor)
@@ -253,7 +273,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             and self._is_swizzled_scales == src._is_swizzled_scales
             and self._scale_e4m3.shape == src._scale_e4m3.shape
             and per_tensor_scale_equal
+            and act_per_tensor_scale_equal
             and self.qdata.shape == src.qdata.shape
+            and self.act_quant_kwargs == src.act_quant_kwargs
         )
 
 
@@ -290,12 +312,13 @@ def nvfp4_to_copy(func, types, args, kwargs):
     res = NVFP4Tensor(
         tensor._scale_e4m3,
         tensor._per_tensor_scale,
+        tensor._act_per_tensor_scale,
         tensor._data,
         tensor._block_size,
         dtype,
-        tensor.mm_config,
         tensor._is_swizzled_scales,
         tensor.use_triton_kernel,
+        tensor.act_quant_kwargs,
     )
     return res
 
@@ -491,11 +514,12 @@ def nvfp4_slice(func, types, args, kwargs):
         sliced_data,
         sliced_scale,
         x._per_tensor_scale,
+        x._act_per_tensor_scale,
         x._block_size,
         x._orig_dtype,
-        x.mm_config,
         x._is_swizzled_scales,
         x.use_triton_kernel,
+        x.act_quant_kwargs,
     )
 
     return return_and_correct_aliasing(func, args, kwargs, result)
@@ -509,11 +533,12 @@ def nvfp4_t(func, types, args, kwargs):
         old.qdata.t(),
         old._scale_e4m3,
         old._per_tensor_scale,
+        old._act_per_tensor_scale,
         old._block_size,
         old._orig_dtype,
-        old.mm_config,
         old._is_swizzled_scales,
         old.use_triton_kernel,
+        old.act_quant_kwargs,
     )
     return new
 
@@ -528,11 +553,12 @@ def nvfp4_view_op(func, types, args, kwargs):
         new_data,
         args[0]._scale_e4m3,
         args[0]._per_tensor_scale,
+        args[0]._act_per_tensor_scale,
         args[0]._block_size,
         args[0]._orig_dtype,
-        args[0].mm_config,
         args[0]._is_swizzled_scales,
         args[0].use_triton_kernel,
+        args[0].act_quant_kwargs,
     )
 
 
@@ -610,17 +636,19 @@ def nvfp4_linear(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")
 
-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
+        # weight_only quant
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         return torch.nn.functional.linear(input_tensor, weight_dequant, bias)
     else:
+        # dynamic quant
+        k = weight_tensor.act_quant_kwargs
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
-            mm_config=config,
-            is_swizzled_scales=True,
-            use_triton_kernel=weight_tensor.use_triton_kernel,
+            block_size=k.block_size,
+            per_tensor_scale=weight_tensor._act_per_tensor_scale,
+            is_swizzled_scales=k.is_swizzled_scales,
+            use_triton_kernel=k.use_triton_kernel,
         )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias)
 
@@ -632,9 +660,7 @@ def nvfp4_mm(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")
 
-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
             input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
@@ -643,11 +669,13 @@ def nvfp4_mm(func, types, args, kwargs):
             return func(input_tensor, weight_dequant)
     else:
         if not isinstance(input_tensor, NVFP4Tensor):
+            k = weight_tensor.act_quant_kwargs
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
-                mm_config=config,
-                is_swizzled_scales=True,
-                use_triton_kernel=weight_tensor.use_triton_kernel,
+                block_size=k.block_size,
+                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                is_swizzled_scales=k.is_swizzled_scales,
+                use_triton_kernel=k.use_triton_kernel,
             )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func)
 
@@ -659,9 +687,7 @@ def nvfp4_addmm(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")
 
-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
             input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
@@ -670,11 +696,13 @@ def nvfp4_addmm(func, types, args, kwargs):
             return torch.addmm(bias, input_tensor, weight_dequant)
     else:
        if not isinstance(input_tensor, NVFP4Tensor):
+            k = weight_tensor.act_quant_kwargs
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
-                mm_config=config,
-                is_swizzled_scales=True,
-                use_triton_kernel=weight_tensor.use_triton_kernel,
+                block_size=k.block_size,
+                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                is_swizzled_scales=k.is_swizzled_scales,
+                use_triton_kernel=k.use_triton_kernel,
             )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func, bias=bias)
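
Putting it together, a hedged end-to-end sketch of how the two modes are now selected at dispatch time (names from the diffs above; shapes, dtype, and device are illustrative, and the dynamic path assumes hardware with NVFP4 matmul support):

    import torch
    import torch.nn.functional as F
    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4Tensor,
        QuantizeTensorToNVFP4Kwargs,
    )

    x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

    # Weight-only: act_quant_kwargs is None, so nvfp4_linear dequantizes the
    # weight and falls back to a high-precision F.linear.
    w_weight_only = NVFP4Tensor.to_nvfp4(w, is_swizzled_scales=True)
    y_weight_only = F.linear(x, w_weight_only)

    # Dynamic: act_quant_kwargs is set, so nvfp4_linear quantizes the
    # activation with those settings before calling the NVFP4 addmm kernel.
    w_dynamic = NVFP4Tensor.to_nvfp4(
        w,
        is_swizzled_scales=True,
        act_quant_kwargs=QuantizeTensorToNVFP4Kwargs(is_swizzled_scales=True),
    )
    y_dynamic = F.linear(x, w_dynamic)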
