Commit 344d201

delete outdated MX inference code (#2615)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent bf5bd5f commit 344d201

5 files changed (+1, -190 lines)

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 0 additions & 71 deletions
@@ -14,17 +14,14 @@
 from torchao.prototype.mx_formats.config import (
     MXFP8Dim1CastKernelChoice,
     MXGemmKernelChoice,
-    MXInferenceLinearConfig,
     MXLinearConfig,
     MXLinearRecipeName,
 )
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
-    SUPPORTED_ELEM_DTYPES,
 )
 from torchao.prototype.mx_formats.mx_linear import (
-    MXInferenceLinear,
     MXLinear,
 )
 from torchao.prototype.mx_formats.mx_subclass import (
@@ -313,77 +310,18 @@ def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     torch.testing.assert_close(x_g_ref, x_g, atol=0.02, rtol=0.02)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-@pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("input_shape", [(2, 4), (1, 2, 4), (1, 1, 2, 4)])
-def test_inference_linear(elem_dtype, bias, input_shape):
-    """
-    Smoke test for inference linear module with mx weight
-    """
-    m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16))
-    m = m.cuda()
-    m_mx = copy.deepcopy(m)
-    config = MXInferenceLinearConfig(block_size=4, elem_dtype=elem_dtype)
-    quantize_(m_mx, config=config)
-
-    x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
-    y_ref = m(x)
-    y_mx = m_mx(x)
-    sqnr = compute_error(y_ref, y_mx)
-    if elem_dtype is torch.float8_e4m3fn:
-        assert sqnr >= 20.0
-    else:
-        assert sqnr >= 11.0
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
-)
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_inference_compile_simple(elem_dtype):
-    """
-    Smoke test for inference compile
-    """
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        if not is_sm_at_least_89():
-            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
-    m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16))
-    m = m.cuda()
-    m_mx = copy.deepcopy(m)
-    config = MXInferenceLinearConfig(block_size=4, elem_dtype=elem_dtype)
-    quantize_(m_mx, config=config)
-    m_mx = torch.compile(m_mx, fullgraph="true")
-
-    x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
-    y_ref = m(x)
-    y_mx = m_mx(x)
-    sqnr = compute_error(y_ref, y_mx)
-    if elem_dtype is torch.float8_e4m3fn:
-        assert sqnr >= 20.0
-    else:
-        assert sqnr >= 11.5
-
-
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
         nn.Linear(32, 32),
     )
-    m2 = copy.deepcopy(m1)
     filter_fn = lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "1"  # noqa: E731
 
     config = MXLinearConfig(block_size=32)
     quantize_(m1, config=config, filter_fn=filter_fn)
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear
 
-    config2 = MXInferenceLinearConfig(block_size=32)
-    quantize_(m2, config=config2, filter_fn=filter_fn)  # noqa: E501
-    assert type(m2[0]) == MXInferenceLinear
-    assert type(m2[1]) == torch.nn.Linear
-
 
 def test_training_print_str():
     m = nn.Sequential(nn.Linear(32, 32))
@@ -394,15 +332,6 @@ def test_training_print_str():
     assert "kernel=emulated" in s
 
 
-def test_inference_print_str():
-    m = nn.Sequential(nn.Linear(32, 32))
-    config = MXInferenceLinearConfig()
-    quantize_(m, config=config)
-    s = str(m)
-    assert "bl_sz=32" in s
-    assert "kernel=emulated" in s
-
-
 test_dtypes = (
     [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
     if TORCH_VERSION_AT_LEAST_2_8
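
The deleted inference tests gate on SQNR between the reference bfloat16 output and the MX output (the `assert sqnr >= ...` lines above). As a reminder of what those thresholds measure, here is a minimal sketch of an SQNR helper in the spirit of torchao's `compute_error`; the exact implementation in torchao may differ:

```python
import torch


def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: higher means the quantized
    # output tracks the high-precision reference more closely.
    noise = ref - actual
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(noise))


# e.g. an SQNR of 20 dB corresponds to a relative error norm of about 10%.
```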

torchao/prototype/mx_formats/README.md

Lines changed: 1 addition & 17 deletions
@@ -45,24 +45,8 @@ quantize_(m, config)
 
 ## MX inference
 
-Note: currently only weight-only quantization is supported.
-
-```python
-import torch
-from torchao.quantization import quantize_
-from torchao.prototype.mx_formats import MXInferenceLinearConfig, MXGemmKernelChoice
-
-m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
-config = MXInferenceLinearConfig(
-    elem_dtype=torch.float8_e4m3fn,
-    block_size=32,
-    gemm_kernel_choice=gemm_kernel_choice,
-)
-quantize_(m, config=config)
+Coming soon!
 
-# do inference (not shown)
-```
 ## MXTensor
 
 This is casts between high precision and MX formats implemented in native PyTorch. Currently
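
The snippet removed from the README above relied on `MXInferenceLinearConfig`, which this commit deletes. The `__init__.py` hunk below shows that `MXFPInferenceConfig` stays exported as the inference entry point, so a rough stand-in for the deleted example could look like the following sketch, which assumes the config's defaults are usable on the target GPU (its fields are not shown in this diff):

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXFPInferenceConfig

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda().bfloat16()
# Construct with defaults only, since this diff does not show the config's fields.
quantize_(m, config=MXFPInferenceConfig())

with torch.no_grad():
    y = m(torch.randn(2, 32, device="cuda", dtype=torch.bfloat16))
```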

torchao/prototype/mx_formats/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -1,6 +1,5 @@
 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
-    MXInferenceLinearConfig,
     MXLinearConfig,
     MXLinearRecipeName,
 )
@@ -18,7 +17,6 @@
 
 __all__ = [
     "MXGemmKernelChoice",
-    "MXInferenceLinearConfig",
     "MXLinearConfig",
     "MXLinearRecipeName",
     "MXFPInferenceConfig",

torchao/prototype/mx_formats/config.py

Lines changed: 0 additions & 45 deletions
@@ -12,8 +12,6 @@
 
 from torchao.core.config import AOBaseConfig
 from torchao.prototype.mx_formats.constants import (
-    DTYPE_FP6_E2M3,
-    DTYPE_FP6_E3M2,
     DTYPE_TO_SHORT_STR,
     SUPPORTED_ELEM_DTYPES,
 )
@@ -163,46 +161,3 @@ def short_str(self) -> str:
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
         return s
-
-
-@dataclass
-class MXInferenceLinearConfig(AOBaseConfig):
-    # block size for scaling, default is 32 to match
-    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
-    # section 5.2
-    block_size: int = 32
-
-    # element dtype, used for activations, weights and gradients
-    elem_dtype: Any = torch.float8_e4m3fn
-    # TODO(future PR): support different elem_dtype for activations vs weights
-
-    # defines the gemm kernel choice, if the chosen kernel is not supported
-    # on the given hardware an exception will be thrown
-    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
-
-    # If True, uses a custom triton kernel for fp4 dequantize
-    use_fp4_custom_triton_dequant_kernel: bool = False
-
-    # If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton
-    # kernels (fused unpack/dequantize).
-    pack_fp6: bool = True
-
-    def __post_init__(self):
-        _validate_elem_dtype(self.elem_dtype)
-        _validate_gemm_kernel_choice(
-            self.gemm_kernel_choice, self.block_size, self.elem_dtype
-        )
-
-    def short_str(self) -> str:
-        """
-        Returns a concise representation of the current config.
-        """
-        s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}"
-        s += f", kernel={self.gemm_kernel_choice.value}"
-        if self.use_fp4_custom_triton_dequant_kernel:
-            s += ", use_fp4_custom_triton_dequant_kernel=True"
-        if self.elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2) and self.pack_fp6:
-            s += ", pack_fp6=True"
-        return s
-
-# TODO(future PR): add a recipe to config API for inference

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 0 additions & 55 deletions
@@ -11,13 +11,11 @@
 from typing import Any, Optional
 
 import torch
-import torch.nn.functional as F
 from torch.distributed._tensor import DTensor
 
 from torchao.prototype.mx_formats.config import (
     MXFP8Dim1CastKernelChoice,
     MXGemmKernelChoice,
-    MXInferenceLinearConfig,
     MXLinearConfig,
 )
 from torchao.prototype.mx_formats.kernels import (
@@ -270,59 +268,6 @@ def extra_repr(self):
         return s
 
 
-class MXInferenceLinear(torch.nn.Linear):
-    """
-    Inference version of MXLinear, with the weight pre-quantized to MX.
-
-    Note: this is weight-only quantization, with the gemm being executed
-    in high precision.
-    """
-
-    @classmethod
-    @torch.no_grad()
-    def from_float(
-        cls,
-        mod,
-        config: Optional[MXInferenceLinearConfig] = MXInferenceLinearConfig(),
-    ):
-        with torch.device("meta"):
-            super_kwargs = {
-                "in_features": mod.in_features,
-                "out_features": mod.out_features,
-                "bias": False,
-            }
-            new_mod = cls(**super_kwargs)
-        # TODO(future PR): set to new_mod.weight directly, will need to work
-        # through some errors
-        new_mod.weight_mx = MXTensor.to_mx(
-            mod.weight,
-            config.elem_dtype,
-            block_size=config.block_size,
-            gemm_kernel_choice=config.gemm_kernel_choice,
-            pack_fp6=config.pack_fp6,
-        )
-        new_mod.bias = mod.bias
-        new_mod.config = config
-        return new_mod
-
-    @torch.no_grad()
-    def forward(self, x):
-        w_hp = self.weight_mx.to_dtype(x.dtype)
-        y = F.linear(x, w_hp, self.bias)
-        return y
-
-    def extra_repr(self):
-        s = f"{super().extra_repr()}, {self.config.short_str()}"
-        return s
-
-
 @register_quantize_module_handler(MXLinearConfig)
 def _mx_linear_transform(module: torch.nn.Module, config: MXLinearConfig):
     return MXLinear.from_float(module, config=config)
-
-
-@register_quantize_module_handler(MXInferenceLinearConfig)
-def _mx_inference_linear_transform(
-    module: torch.nn.Module, config: MXInferenceLinearConfig
-):
-    return MXInferenceLinear.from_float(module, config=config)
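
For reference, the deleted `MXInferenceLinear.forward` is dequantize-then-matmul: the MX weight is expanded back to the activation dtype and a regular high-precision gemm runs. A standalone sketch of the same numerics, using only the `MXTensor` methods visible in the removed code (`to_mx` and `to_dtype`) and hypothetical toy shapes, is:

```python
import torch
import torch.nn.functional as F

from torchao.prototype.mx_formats.mx_tensor import MXTensor

# Hypothetical shapes; block_size must divide the reduction dimension (32 here).
w = torch.randn(8, 32, device="cuda", dtype=torch.bfloat16)
x = torch.randn(2, 32, device="cuda", dtype=torch.bfloat16)

# Quantize the weight to MX fp8, then dequantize and run the gemm in bfloat16,
# mirroring what MXInferenceLinear.from_float + forward did.
w_mx = MXTensor.to_mx(w, torch.float8_e4m3fn, block_size=32)
y = F.linear(x, w_mx.to_dtype(x.dtype))
```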
