
Commit c09eff4

[Wave] Add quantized linear layer kernel
Signed-off-by: nithinsubbiah <[email protected]>
1 parent 52082b1 commit c09eff4

3 files changed: 440 additions & 1 deletion
Lines changed: 283 additions & 0 deletions
@@ -0,0 +1,283 @@
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from torch import nn
from torch.testing import assert_close
import math

import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.iree_utils import generate_iree_ref
from iree.turbine.kernel.wave.utils.general_utils import (
    get_default_scheduling_params,
)
from iree.turbine.kernel.wave.utils.run_utils import (
    set_default_run_config,
)
from iree.turbine.kernel.wave.utils.mma_utils import (
    get_mfma_load_elems_per_thread,
    get_mfma_store_elems_per_thread,
)
from iree.turbine.kernel.wave.scheduling.schedule import SchedulingType
from iree.turbine.kernel.wave.compile import WaveCompileOptions, wave_compile
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.utils.general_utils import (
    torch_dtype_to_wave,
    torch_dtype_range,
)


def get_quant_linear_kernel(
    shape: tuple[int],
    quant_params,
    dynamic_dims: bool = False,
    mfma_variant: MMAType = MMAType.F32_16x16x32_F8,
    use_bias: bool = False,
):
    # Input sizes
    B = tkl.sym.B
    M = tkl.sym.M
    N = tkl.sym.N
    K = tkl.sym.K
    # Workgroup tile sizes
    BLOCK_B = tkl.sym.BLOCK_B
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K = tkl.sym.BLOCK_K
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(2, 2, 1),
            vector_shapes={B: 0},
            mma_type=mfma_variant,
        )
    ]

    # With dynamic dimensions, we need to add an assumption on how big
    # the reduction dimension is to determine whether we can schedule or not.
    if dynamic_dims:
        constraints += [tkw.Assumption(K > BLOCK_K * 4)]

    [weight_scale, input_scale, quant_dtype] = quant_params
    [qdtype_min, qdtype_max] = torch_dtype_range(quant_dtype)

    def clamp_tensor(source_reg, lower_bound, upper_bound):
        clamped = tkw.minimum(source_reg, upper_bound)
        clamped = tkw.maximum(clamped, lower_bound)
        clamped = tkw.cast(clamped, torch_dtype_to_wave(quant_dtype))
        return clamped

    # Wave-level micro-kernel.
    # Since warps are not directly addressable, there is no
    # explicit notion of a warp id (like a workgroup or thread id).
    # This kernel uses the input sizes M, N, K throughout, as the tiling
    # and data movement strategy is determined during the compilation process.
    # These can be influenced by introducing constraints.
    def gemm_core(a, b, c_reg, result):
        # TODO: Registers for quantization scaling of inputs. Remove once scalar
        # codegen is enabled.
        a_scale = tkl.Register[B, M, K, tkl.f16](1 / input_scale.item())
        a_clamp_max = tkl.Register[B, M, K, tkl.f16](qdtype_max)
        a_clamp_min = tkl.Register[B, M, K, tkl.f16](qdtype_min)
        b_scale = tkl.Register[N, K, tkl.f16](1 / weight_scale.item())
        b_clamp_max = tkl.Register[N, K, tkl.f16](qdtype_max)
        b_clamp_min = tkl.Register[N, K, tkl.f16](qdtype_min)
        a_scale_deq = tkl.Register[B, M, N, tkl.f32](input_scale.item())
        b_scale_deq = tkl.Register[B, M, N, tkl.f32](weight_scale.item())

        @tkw.reduction(K, init_args=[c_reg])
        def repeat(
            acc: tkl.Register[B, M, N, tkl.f32],
        ) -> tkl.Register[B, M, N, tkl.f32]:
            a_reg = tkw.read(a)
            b_reg = tkw.read(b)
            a_reg *= a_scale
            a_reg = clamp_tensor(a_reg, a_clamp_min, a_clamp_max)
            b_reg *= b_scale
            b_reg = clamp_tensor(b_reg, b_clamp_min, b_clamp_max)
            acc = tkw.mma(a_reg, b_reg, acc)
            acc *= a_scale_deq * b_scale_deq
            return acc

        tkw.write(
            tkw.cast(repeat, tkl.f16),
            result,
        )

    @tkw.wave(constraints)
    def gemm(
        a: tkl.Memory[B, M, K, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
        result: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f16],
    ):
        c_reg = tkl.Register[B, M, N, tkl.f32](0.0)
        gemm_core(a, b, c_reg, result)

    @tkw.wave(constraints)
    def gemm_with_bias(
        a: tkl.Memory[B, M, K, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
        bias: tkl.Memory[N, ADDRESS_SPACE, tkl.f16],
        result: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f16],
    ):
        bias_reg = tkw.read(bias)
        bias_reg = tkw.broadcast(bias_reg, target_shape=[B, M, N])
        bias_reg = tkw.cast(bias_reg, tkl.f32)
# We can get "free" bias-add by setting bias as the initial
146+
# value of accumulator to the mma
147+
        gemm_core(a, b, bias_reg, result)

    hyperparams = {
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
        BLOCK_B: 1,
        BLOCK_M: 64,
        BLOCK_N: 64,
        BLOCK_K: 32,
        N: shape[1],
        K: shape[0],
    }
    hyperparams.update(get_default_scheduling_params())

    dynamic_symbols = [B, M]

    options = WaveCompileOptions(
        subs=hyperparams,
        canonicalize=True,
        dynamic_symbols=dynamic_symbols,
        dynamic_symbols_map={},
    )
    options = set_default_run_config(options)
    gemm_kernel = gemm
    if use_bias:
        gemm_kernel = gemm_with_bias
    compiled_gemm = wave_compile(options, gemm_kernel)
    return compiled_gemm


def extract_quant_params(quant_params: dict):
    weight_scale = torch.tensor(quant_params["weight_scale"]).view(
        quant_params["weight_scale_shape"]
    )
    input_scale = torch.tensor(quant_params["input_scale"]).view(
        quant_params["input_scale_shape"]
    )
    qdtype = quant_params["qdtype"]
    return weight_scale, input_scale, qdtype


LINEAR_SUPPORTED_DTYPE = {torch.float16}

# Only per-tensor quantization is supported
class WaveQuantLinear(nn.Module):
"""Fork of nn.Linear implementation but modified to handle Wave Kernel"""
194+
195+
    def __init__(
        self,
        in_features,
        out_features,
        quant_params,
        bias=True,
        device=None,
        dtype=None,
    ):
        device = device or torch.device("cuda:0")
        dtype = dtype or torch.float16

        if device.type != "cuda":
raise ValueError(f"{self.__class__.__name__} only support GPU device.")
209+
        if dtype not in LINEAR_SUPPORTED_DTYPE:
            raise ValueError(
                f"{self.__class__.__name__} does not support dtype: {dtype}."
            )
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()
        # Wave related initialization
        self.weight_scale, self.input_scale, self.qdtype = extract_quant_params(
            quant_params
        )
        if self.weight_scale.numel() != 1 or self.input_scale.numel() != 1:
            raise ValueError("Only per-tensor quantization is currently supported")
        self.kernel = get_quant_linear_kernel(
            [in_features, out_features],
            [self.weight_scale, self.input_scale, self.qdtype],
            use_bias=bias,
        )

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert len(input.shape) >= 2
        # Determine parameter shapes
        input_len = input.shape[-2]
        batch = input.shape[0:-2]

        # Compute "flattened" batch shapes
        flat_batch = math.prod(batch)
        out_features = self.weight.shape[0]
        output_shape = [flat_batch, input_len, out_features]

        # Setup and run kernel
        output = torch.empty(
            output_shape, dtype=self.weight.dtype, device=self.weight.device
        )
        self.kernel.options.dynamic_symbols_map = {
            tkl.sym.B: flat_batch,
            tkl.sym.M: input_len,
        }
        if self.bias is None:
            self.kernel(
                input.view(flat_batch, input_len, input.shape[-1]), self.weight, output
            )
        else:
            self.kernel(
                input.view(flat_batch, input_len, input.shape[-1]),
                self.weight,
                self.bias,
                output,
            )

        # Return non-flattened shape
        return output.view(*batch, input_len, out_features)

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"

iree/turbine/kernel/wave/utils/general_utils.py

Lines changed: 22 additions & 1 deletion
@@ -317,7 +317,7 @@ def partial(func, *args, **kwargs):
     torch.float8_e5m2: tkl.f8e5m2,
     torch.float8_e5m2fnuz: tkl.f8e5m2fnuz,
     torch.float8_e4m3fn: tkl.f8e4m3fn,
-    torch.torch.float8_e4m3fnuz: tkl.f8e4m3fnuz,
+    torch.float8_e4m3fnuz: tkl.f8e4m3fnuz,
     torch.float16: tkl.f16,
     torch.float32: tkl.f32,
     torch.float64: tkl.f64,
@@ -327,6 +327,20 @@ def partial(func, *args, **kwargs):
     torch.bool: tkl.bool,
 }
 
+TORCH_DTYPE_RANGE = {
+    torch.bfloat16: [-3.3895313892515355e38, 3.3895313892515355e38],
+    torch.float8_e5m2: [-57344.0, 57344.0],
+    torch.float8_e5m2fnuz: [-57344.0, 57344.0],
+    torch.float8_e4m3fn: [-448.0, 448.0],
+    torch.float8_e4m3fnuz: [-240.0, 240.0],
+    torch.float16: [-65504.0, 65504.0],
+    torch.float32: [-3.4028234663852886e38, 3.4028234663852886e38],
+    torch.float64: [-1.7976931348623157e308, 1.7976931348623157e308],
+    torch.int16: [-32768, 32767],
+    torch.int32: [-2147483648, 2147483647],
+    torch.int64: [-9223372036854775808, 9223372036854775807],
+}
+
 
 def torch_dtype_to_wave(torch_dtype: torch.dtype) -> Any:
     try:
@@ -335,6 +349,13 @@ def torch_dtype_to_wave(torch_dtype: torch.dtype) -> Any:
         raise ValueError(f"Unable to map torch dtype {torch_dtype} to Wave.")
 
 
+def torch_dtype_range(torch_dtype: torch.dtype) -> Any:
+    try:
+        return TORCH_DTYPE_RANGE[torch_dtype]
+    except KeyError:
+        raise ValueError(f"Unable to retrieve torch dtype {torch_dtype} range.")
+
+
 def is_shared_write(node: CustomOp) -> bool:
     return (
         isinstance(node, Write)
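As a quick illustration of the new helpers (not part of the diff): torch_dtype_range returns the [min, max] representable values for a dtype, which the kernel above uses as clamp bounds before casting to the quantized type obtained from torch_dtype_to_wave.

# Illustrative only; mirrors how clamp_tensor in the kernel uses these helpers.
import torch
from iree.turbine.kernel.wave.utils.general_utils import (
    torch_dtype_range,
    torch_dtype_to_wave,
)

lo, hi = torch_dtype_range(torch.float8_e4m3fnuz)         # [-240.0, 240.0]
wave_dtype = torch_dtype_to_wave(torch.float8_e4m3fnuz)   # tkl.f8e4m3fnuz

x = torch.tensor([-1000.0, 5.0, 1000.0])
clamped = x.clamp(min=lo, max=hi)  # tensor([-240., 5., 240.])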
