@@ -6,7 +6,7 @@
 
 import sys
 from enum import Enum
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Dict, Optional
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -24,7 +24,7 @@
     tensor_size_hp_to_fp4x2,
 )
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
-from torchao.utils import ceil_div, fill_defaults
+from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults
 
 E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
 
@@ -38,6 +38,7 @@ class NVFP4MMConfig(Enum):
     WEIGHT_ONLY = "weight_only"
 
 
+# TODO(future PR): move over to TorchAOBaseTensor's dispatch
 def implements(aten_ops):
     """Register aten ops to the NVFP4 op table"""
 
@@ -49,7 +50,7 @@ def decorator(func):
     return decorator
 
 
-class NVFP4Tensor(torch.Tensor):
+class NVFP4Tensor(TorchAOBaseTensor):
     """NVIDIA FP4 (NVFP4) Tensor subclass.
 
     This implements the NVIDIA variant of MX FP4 format, which uses a specific
@@ -59,20 +60,22 @@ class NVFP4Tensor(torch.Tensor):
         qdata: Packed FP4 data (2 values per byte)
         _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
         _per_tensor_scale: Optional global per-tensor scale in float32 format
-        _block_size: Block size for quantization (fixed at 16)
-        _orig_dtype: Original tensor dtype before quantization
-        _is_swizzled_scales: Whether scales are stored in swizzled (blocked) format
-        mm_config: Matrix multiplication configuration
+        _block_size (int): Block size for quantization (fixed at 16)
+        _orig_dtype (torch.dtype): Original tensor dtype before quantization
+        _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        mm_config (NVFP4MMConfig): Matrix multiplication configuration
+        use_triton_kernel (bool): Whether to use triton kernels
     """
 
-    qdata: torch.Tensor
-    _scale_e4m3: torch.Tensor
-    _per_tensor_scale: Optional[torch.Tensor]
-    _block_size: int
-    _orig_dtype: torch.dtype
-    _is_swizzled_scales: bool
-    mm_config: NVFP4MMConfig
-    use_triton_kernel: bool
+    tensor_data_names = ["qdata", "_scale_e4m3"]
+    optional_tensor_data_names = ["_per_tensor_scale"]
+    tensor_attribute_names = [
+        "_block_size",
+        "_orig_dtype",
+        "mm_config",
+        "_is_swizzled_scales",
+        "use_triton_kernel",
+    ]
 
     def __new__(
         cls,
@@ -173,52 +176,6 @@ def to_nvfp4(
             use_triton_kernel,
         )
 
-    def __tensor_flatten__(self):
-        ctx = {
-            "_block_size": self._block_size,
-            "_orig_dtype": self._orig_dtype,
-            "_is_swizzled_scales": self._is_swizzled_scales,
-            "mm_config": self.mm_config,
-            "use_triton_kernel": self.use_triton_kernel,
-        }
-        tensor_list = ["qdata", "_scale_e4m3"]
-        if self._per_tensor_scale is not None:
-            tensor_list.append("_per_tensor_scale")
-        return tensor_list, ctx
-
-    def _apply_fn_to_data(self, fn: Callable):
-        """Applies a fn to all tensor components stored on this class"""
-        tensor_names, ctx = self.__tensor_flatten__()
-        new_tensors = {}
-        for name in tensor_names:
-            new_tensors[name] = fn(getattr(self, name))
-        if "_per_tensor_scale" not in tensor_names:
-            new_tensors["_per_tensor_scale"] = None
-        return self.__class__.__tensor_unflatten__(
-            new_tensors,
-            ctx,
-            None,
-            None,
-        )
-
-    @staticmethod
-    def __tensor_unflatten__(
-        inner_tensors,
-        metadata,
-        outer_size,
-        outer_stride,
-    ):
-        return NVFP4Tensor(
-            inner_tensors["qdata"],
-            inner_tensors["_scale_e4m3"],
-            inner_tensors.get("_per_tensor_scale", None),
-            metadata["_block_size"],
-            metadata["_orig_dtype"],
-            metadata["mm_config"],
-            metadata.get("_is_swizzled_scales", False),
-            metadata.get("use_triton_kernel", False),
-        )
-
     # Do not force the NVFP4Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl
 
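The net effect of this diff is that the hand-written `__tensor_flatten__`, `__tensor_unflatten__`, and `_apply_fn_to_data` boilerplate is replaced by three declarative class attributes that `TorchAOBaseTensor` consumes. As a rough mental model, here is a minimal sketch of the behavior those attributes are expected to drive; the helper names `flatten_from_names` and `unflatten_from_names` are invented for illustration and are not torchao's actual implementation:

from typing import Any, Dict, List, Tuple


def flatten_from_names(t) -> Tuple[List[str], Dict[str, Any]]:
    # Mirrors the deleted __tensor_flatten__: list the inner tensors and
    # stash the non-tensor attributes in a context dict.
    names = list(t.tensor_data_names)
    # Optional inner tensors are included only when actually present,
    # matching the old `if self._per_tensor_scale is not None` branch.
    for name in t.optional_tensor_data_names:
        if getattr(t, name) is not None:
            names.append(name)
    ctx = {name: getattr(t, name) for name in t.tensor_attribute_names}
    return names, ctx


def unflatten_from_names(cls, inner_tensors: Dict[str, Any], ctx: Dict[str, Any]):
    # Mirrors the deleted __tensor_unflatten__: rebuild the subclass from
    # the flattened pieces, in declared order.
    args = [inner_tensors[name] for name in cls.tensor_data_names]
    args += [inner_tensors.get(name) for name in cls.optional_tensor_data_names]
    args += [ctx[name] for name in cls.tensor_attribute_names]
    return cls(*args)

Note the ordering constraint this implies: `tensor_data_names + optional_tensor_data_names + tensor_attribute_names` has to line up with the constructor's positional arguments, which is presumably why the attribute list places `mm_config` before `_is_swizzled_scales`, matching the argument order the deleted `__tensor_unflatten__` passed to `NVFP4Tensor(...)`.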