# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from torch import nn
from torch.testing import assert_close
import math

import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.iree_utils import generate_iree_ref
from iree.turbine.kernel.wave.utils.general_utils import (
    get_default_scheduling_params,
    torch_dtype_to_wave,
    torch_dtype_range,
)
from iree.turbine.kernel.wave.utils.run_utils import (
    set_default_run_config,
)
from iree.turbine.kernel.wave.utils.mma_utils import (
    get_mfma_load_elems_per_thread,
    get_mfma_store_elems_per_thread,
)
from iree.turbine.kernel.wave.scheduling.schedule import SchedulingType
from iree.turbine.kernel.wave.compile import WaveCompileOptions, wave_compile
from iree.turbine.kernel.wave.constraints import MMAType


def get_quant_linear_kernel(
    shape: tuple[int, int],
    quant_params,
    dynamic_dims: bool = False,
    mfma_variant: MMAType = MMAType.F32_16x16x32_F8,
    use_bias: bool = False,
):
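    """Compile a Wave GEMM kernel implementing a per-tensor-quantized linear layer.

    The returned callable has the signature kernel(input, weight, output) or,
    when use_bias is set, kernel(input, weight, bias, output), where input has
    shape (B, M, in_features), weight has shape (out_features, in_features),
    bias has shape (out_features,), and output has shape (B, M, out_features),
    all in float16.
    """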
    # Input sizes
    B = tkl.sym.B
    M = tkl.sym.M
    N = tkl.sym.N
    K = tkl.sym.K
    # Workgroup tile sizes
    BLOCK_B = tkl.sym.BLOCK_B
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K = tkl.sym.BLOCK_K
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(2, 2, 1),
            vector_shapes={B: 0},
            mma_type=mfma_variant,
        )
    ]
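
    # Taken together, these constraints give each workgroup a
    # BLOCK_M x BLOCK_N output tile, computed by a 2x2 grid of 64-thread
    # waves, with the batch dimension mapped to the third workgroup axis and
    # K reduced sequentially in BLOCK_K-sized steps.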

    # With dynamic dimensions, we need to add an assumption on how big
    # the reduction dimension is to determine whether we can schedule or not.
    if dynamic_dims:
        constraints += [tkw.Assumption(K > BLOCK_K * 4)]

    weight_scale, input_scale, quant_dtype = quant_params
    qdtype_min, qdtype_max = torch_dtype_range(quant_dtype)

    # Clamp a register to the representable range of the quantized dtype and
    # cast it to that dtype.
    def clamp_tensor(source_reg, lower_bound, upper_bound):
        clamped = tkw.minimum(source_reg, upper_bound)
        clamped = tkw.maximum(clamped, lower_bound)
        clamped = tkw.cast(clamped, torch_dtype_to_wave(quant_dtype))
        return clamped
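
    # Quantized matmul scheme used below: with q_a = clamp(a / input_scale)
    # and q_b = clamp(b / weight_scale) cast to the quantized dtype, the
    # result is approximated by (q_a @ q_b^T) * input_scale * weight_scale,
    # so the f32 accumulator is rescaled once after the K reduction.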

    # Wave-level micro-kernel.
    # Since warps are not directly addressable, there is no
    # explicit notion of a warp id (like a workgroup or thread id).
    # This kernel uses the input sizes M, N, K throughout, as the tiling
    # and data movement strategy is determined during the compilation process.
    # These can be influenced by introducing constraints.
    def gemm_core(a, b, c_reg, result):
        # TODO: Registers for quantization scaling of inputs. Remove once scalar
        # codegen is enabled.
        a_scale = tkl.Register[B, M, K, tkl.f16](1 / input_scale.item())
        a_clamp_max = tkl.Register[B, M, K, tkl.f16](qdtype_max)
        a_clamp_min = tkl.Register[B, M, K, tkl.f16](qdtype_min)
        b_scale = tkl.Register[N, K, tkl.f16](1 / weight_scale.item())
        b_clamp_max = tkl.Register[N, K, tkl.f16](qdtype_max)
        b_clamp_min = tkl.Register[N, K, tkl.f16](qdtype_min)
        a_scale_deq = tkl.Register[B, M, N, tkl.f32](input_scale.item())
        b_scale_deq = tkl.Register[B, M, N, tkl.f32](weight_scale.item())

        @tkw.reduction(K, init_args=[c_reg])
        def repeat(
            acc: tkl.Register[B, M, N, tkl.f32],
        ) -> tkl.Register[B, M, N, tkl.f32]:
            a_reg = tkw.read(a)
            b_reg = tkw.read(b)
            a_reg *= a_scale
            a_reg = clamp_tensor(a_reg, a_clamp_min, a_clamp_max)
            b_reg *= b_scale
            b_reg = clamp_tensor(b_reg, b_clamp_min, b_clamp_max)
            acc = tkw.mma(a_reg, b_reg, acc)
            return acc

        # Dequantize once after the K reduction; rescaling inside the loop
        # would multiply already-accumulated tiles by the scales again on
        # every iteration.
        dequantized = repeat * a_scale_deq * b_scale_deq
        tkw.write(
            tkw.cast(dequantized, tkl.f16),
            result,
        )

    @tkw.wave(constraints)
    def gemm(
        a: tkl.Memory[B, M, K, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
        result: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f16],
    ):
        c_reg = tkl.Register[B, M, N, tkl.f32](0.0)
        gemm_core(a, b, c_reg, result)

    @tkw.wave(constraints)
    def gemm_with_bias(
        a: tkl.Memory[B, M, K, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
        bias: tkl.Memory[N, ADDRESS_SPACE, tkl.f16],
        result: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f16],
    ):
        bias_reg = tkw.read(bias)
        bias_reg = tkw.broadcast(bias_reg, target_shape=[B, M, N])
        bias_reg = tkw.cast(bias_reg, tkl.f32)
        # We get a "free" bias-add by using the bias as the initial value of
        # the MMA accumulator. Since the accumulator is multiplied by the
        # dequantization scales after the K reduction, pre-divide the bias by
        # those scales so it comes out of the kernel unchanged.
        inv_deq_scale = tkl.Register[B, M, N, tkl.f32](
            1 / (input_scale.item() * weight_scale.item())
        )
        gemm_core(a, b, bias_reg * inv_deq_scale, result)

    hyperparams = {
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
        BLOCK_B: 1,
        BLOCK_M: 64,
        BLOCK_N: 64,
        BLOCK_K: 32,
        N: shape[1],
        K: shape[0],
    }
    hyperparams.update(get_default_scheduling_params())

    dynamic_symbols = [B, M]

    options = WaveCompileOptions(
        subs=hyperparams,
        canonicalize=True,
        dynamic_symbols=dynamic_symbols,
        dynamic_symbols_map={},
    )
    options = set_default_run_config(options)
    gemm_kernel = gemm
    if use_bias:
        gemm_kernel = gemm_with_bias
    compiled_gemm = wave_compile(options, gemm_kernel)
    return compiled_gemm


def extract_quant_params(quant_params: dict):
    weight_scale = torch.tensor(quant_params["weight_scale"]).view(
        quant_params["weight_scale_shape"]
    )
    input_scale = torch.tensor(quant_params["input_scale"]).view(
        quant_params["input_scale_shape"]
    )
    qdtype = quant_params["qdtype"]
    return weight_scale, input_scale, qdtype


LINEAR_SUPPORTED_DTYPE = {torch.float16}


# Only per-tensor quantization is supported.
class WaveQuantLinear(nn.Module):
    """Fork of the nn.Linear implementation, modified to dispatch to a Wave
    quantized GEMM kernel."""

    def __init__(
        self,
        in_features,
        out_features,
        quant_params,
        bias=True,
        device=None,
        dtype=None,
    ):
        device = device or torch.device("cuda:0")
        dtype = dtype or torch.float16

        if device.type != "cuda":
            raise ValueError(f"{self.__class__.__name__} only supports CUDA devices.")
        if dtype not in LINEAR_SUPPORTED_DTYPE:
            raise ValueError(
                f"{self.__class__.__name__} does not support dtype: {dtype}."
            )
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()
        # Wave related initialization.
        self.weight_scale, self.input_scale, self.qdtype = extract_quant_params(
            quant_params
        )
        if self.weight_scale.numel() != 1 or self.input_scale.numel() != 1:
            raise ValueError("Only per-tensor quantization is currently supported")
        self.kernel = get_quant_linear_kernel(
            (in_features, out_features),
            [self.weight_scale, self.input_scale, self.qdtype],
            use_bias=bias,
        )

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
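        # (kaiming_uniform_ in fan_in mode samples uniform(-bound, bound) with
        # bound = gain * sqrt(3 / fan_in) and gain = sqrt(2 / (1 + a^2)), so
        # a = sqrt(5) gives bound = sqrt(1/3) * sqrt(3 / fan_in) = 1 / sqrt(fan_in).)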
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert len(input.shape) >= 2
        # Determine the input shapes.
        input_len = input.shape[-2]
        batch = input.shape[0:-2]

        # Compute the "flattened" batch shape.
        flat_batch = math.prod(batch)
        out_features = self.weight.shape[0]
        output_shape = [flat_batch, input_len, out_features]

        # Set up and run the kernel.
        output = torch.empty(
            output_shape, dtype=self.weight.dtype, device=self.weight.device
        )
        self.kernel.options.dynamic_symbols_map = {
            tkl.sym.B: flat_batch,
            tkl.sym.M: input_len,
        }
        if self.bias is None:
            self.kernel(
                input.view(flat_batch, input_len, input.shape[-1]), self.weight, output
            )
        else:
            self.kernel(
                input.view(flat_batch, input_len, input.shape[-1]),
                self.weight,
                self.bias,
                output,
            )

        # Restore the original (non-flattened) batch shape.
        return output.view(*batch, input_len, out_features)

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
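

# Minimal usage sketch (illustrative only): the scale values, shapes, and the
# fp8 qdtype below are arbitrary placeholders chosen to match the default
# MFMA F8 variant, and assume a supported GPU plus fp8 handling in the
# torch_dtype_range/torch_dtype_to_wave helpers.
if __name__ == "__main__":
    example_quant_params = {
        "weight_scale": [0.01],  # hypothetical per-tensor scale
        "weight_scale_shape": [1],
        "input_scale": [0.02],  # hypothetical per-tensor scale
        "input_scale_shape": [1],
        "qdtype": torch.float8_e4m3fnuz,  # assumed quantized dtype
    }
    layer = WaveQuantLinear(
        in_features=64, out_features=128, quant_params=example_quant_params
    )
    x = torch.randn(2, 16, 64, dtype=torch.float16, device="cuda:0")
    y = layer(x)
    print(y.shape)  # torch.Size([2, 16, 128])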