
Commit 0d8be61

Add fp32 <-> mx4 quantization operators (#446)
1 parent b64c34e commit 0d8be61

File tree

4 files changed: +101 -0 lines changed

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .operator import Operator
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
import argparse
from typing import Callable, Generator, List, Optional, Tuple

import torch

# We are benchmarking the kernel used inside quantize_comm. As such, we use
# the fp32_to_mx4 fbgemm API rather than the quantize_mx API.
from fbgemm_gpu.quantize_utils import fp32_to_mx4, RoundingMode

from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    register_benchmark,
    register_x_val,
)


class Operator(BenchmarkOperator):
    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        # Inputs are generated later, in get_input_iter.
        self.reset_dynamo = True

    def get_input_iter(self) -> Generator:
        for sz in [24048, 1024 * 1024, 64 * 1024 * 1024, 64 * 1024 * 1024 + 16]:
            _input = torch.randn((sz,), device=self.device, dtype=torch.float32)
            yield _input, 32, 2, 1, RoundingMode.even, False

    @register_benchmark(baseline=True, fwd_only=True)
    def fbgemm_fp32_to_mx4(self, *args) -> Callable:
        return lambda: fp32_to_mx4(*args, use_triton=True)

    @register_x_val(
        label="(Size, Group Size, ebits, mbits, rounding_mode, stochastic_casting)"
    )
    def get_x_val(self, example_inputs) -> Tuple[int, int, int, int, RoundingMode, bool]:
        input_tensor, group_size, ebits, mbits, rounding_mode, stochastic_casting = (
            example_inputs
        )
        return (
            input_tensor.numel(),
            group_size,
            ebits,
            mbits,
            rounding_mode,
            stochastic_casting,
        )
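
For reference, a minimal standalone sketch of the quantize call this operator measures (a hedged example, not part of the commit: it assumes a CUDA device and an fbgemm_gpu build that ships the Triton kernels; the argument order mirrors the benchmark's yield above):

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, RoundingMode

x = torch.randn(1024 * 1024, device="cuda", dtype=torch.float32)
# group_size=32, ebits=2, mbits=1 is the MX4 configuration swept above;
# RoundingMode.even and stochastic_casting=False match the yielded inputs.
packed = fp32_to_mx4(x, 32, 2, 1, RoundingMode.even, False, use_triton=True)
# The output is a packed uint8 tensor (4-bit elements plus per-group shared
# scales), so it is roughly 8x smaller than the fp32 input.
print(packed.dtype, packed.numel())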
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .operator import Operator
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
import argparse
from typing import Callable, Generator, List, Optional, Tuple

import torch

# We are benchmarking the kernel used inside quantize_comm. As such, we use
# the mx4_to_fp32 fbgemm API rather than the quantize_mx API.
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32

from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    register_benchmark,
    register_x_val,
)


class Operator(BenchmarkOperator):
    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        # Inputs are generated later, in get_input_iter.
        self.reset_dynamo = True

    def get_input_iter(self) -> Generator:
        for sz in [12024, 512 * 1024, 32 * 1024 * 1024, 32 * 1024 * 1024 + 16]:
            ebits = 2
            mbits = 1
            group_size = 32
            _input = fp32_to_mx4(
                torch.randn((sz,), device=self.device, dtype=torch.float32),
                group_size,
                ebits,
                mbits,
            )
            yield _input, group_size, ebits, mbits

    @register_benchmark(baseline=True, fwd_only=True)
    def fbgemm_mx4_to_fp32(
        self, tensor: torch.Tensor, group_size: int, ebits: int, mbits: int
    ) -> Callable:
        return lambda: mx4_to_fp32(
            tensor=tensor,
            group_size=group_size,
            use_triton=True,
            ebits=ebits,
            mbits=mbits,
        )

    @register_x_val(label="(Size, Group Size, ebits, mbits)")
    def get_x_val(self, example_inputs) -> Tuple[int, int, int, int]:
        input_tensor, group_size, ebits, mbits = example_inputs
        return (input_tensor.numel(), group_size, ebits, mbits)
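
To tie the two operators together, a hedged round-trip sketch under the same assumptions as above (and assuming mx4_to_fp32 returns a flat fp32 tensor with the original element count):

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32

group_size, ebits, mbits = 32, 2, 1
x = torch.randn(512 * 1024, device="cuda", dtype=torch.float32)
packed = fp32_to_mx4(x, group_size, ebits, mbits)
restored = mx4_to_fp32(
    tensor=packed, group_size=group_size, use_triton=True, ebits=ebits, mbits=mbits
)
# MX4 is a lossy 4-bit format: expect coarse reconstruction, not equality.
print((restored - x).abs().mean())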
