|
12 | 12 | "mmtfp", |
13 | 13 | "mmt_block_scaled_offset_q4_unsigned", |
14 | 14 | "mmt_block_scaled_q8", |
| 15 | + "mmt_super_block_scaled_offset_q4_unsigned", |
15 | 16 | ] |
16 | 17 |
|
17 | 18 |
|
@@ -95,6 +96,161 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): |
95 | 96 | kb.yield_results(*call_function(target_function, *kb.arg_bindings)) |
96 | 97 |
|
97 | 98 |
|
@CustomOp.register(library=LIBRARY)
class mmt_super_block_scaled_offset_q4_unsigned(CustomOp):
    """Super block scaled q4 matmul with transposed RHS.

    Arguments:

    * `a`: [B, M, K]
    * `d`: [N, SUP_COUNT, 1]
    * `dmin`: [N, SUP_COUNT, 1]
    * `sb_scales_hi`: [N, SUP_COUNT, SUB_COUNT // 4]
    * `sb_scales_low`: [N, SUP_COUNT, SUB_COUNT // 2]
    * `sb_mins_hi`: [N, SUP_COUNT, SUB_COUNT // 4]
    * `sb_mins_low`: [N, SUP_COUNT, SUB_COUNT // 2]
    * `qs`: [N, SUP_COUNT, SUB_COUNT, BS // 2]

    Where: `K == SUP_COUNT * SUB_COUNT * BS`

    Given this and hi/low combined into a single value, the dequantization
    formula is:

    ```
    d_scaled = (d * sb_scales).unsqueeze(-1)
    dmin_scaled = (dmin * sb_mins).unsqueeze(-1)
    return d_scaled * qs - dmin_scaled
    ```

    The result has shape [B, M, N] and the same dtype as `a`.
    """

    signature = (
        "mmt_super_block_scaled_offset_q4_unsigned("
        "Tensor a, Tensor d, Tensor dmin, "
        "Tensor sb_scales_hi, Tensor sb_scales_low, "
        "Tensor sb_mins_hi, Tensor sb_mins_low, "
        "Tensor qs"
        ") -> (Tensor)"
    )

    def select(self, ksel: KernelSelection):
        """Validate the argument tensors and declare the result tensor.

        Checks that `a` is a 3d floating point tensor and that every other
        argument has the shape listed in the class docstring, relative to the
        dimensions read from `qs`.

        Raises:
            ValueError: if `a` is not floating point or not 3d, or if any
                argument shape is inconsistent with `qs` / `a`.
        """
        # Argument order follows `signature`.
        a_desc = ksel.arg_tensor(0)
        d_desc = ksel.arg_tensor(1)
        dmin_desc = ksel.arg_tensor(2)
        sb_scales_hi_desc = ksel.arg_tensor(3)
        sb_scales_low_desc = ksel.arg_tensor(4)
        sb_mins_hi_desc = ksel.arg_tensor(5)
        sb_mins_low_desc = ksel.arg_tensor(6)
        qs_desc = ksel.arg_tensor(7)

        # a arg
        *batch_dims, m, k = a_desc.t.shape
        a_desc.specialize_dims(-1)
        if not a_desc.t.dtype.is_floating_point:
            raise ValueError(
                f"mmt_super_block_scaled_offset_q4_unsigned arg 'a': Expected floating point (got {a_desc.t.dtype})"
            )
        if len(batch_dims) != 1:
            raise ValueError(
                f"mmt_super_block_scaled_offset_q4_unsigned arg 'a': Expected 3d tensor (got {a_desc.t.shape})"
            )

        # qs arg
        n, sup_count, sub_count, bs_div2 = qs_desc.t.shape
        qs_desc.specialize_all_dims()
        # The last dim of `qs` is BS // 2 (see class docstring).
        bs = bs_div2 * 2
        if k != (sup_count * sub_count * bs):
            raise ValueError(
                f"mmt_super_block_scaled_offset_q4_unsigned arg 'qs': Incorrect shape (got {qs_desc.t.shape}, k={k})"
            )

        # d arg
        v_n, v_sup_count, one = d_desc.t.shape
        d_desc.specialize_all_dims()
        if v_n != n or v_sup_count != sup_count or one != 1:
            raise ValueError(
                f"mmt_super_block_scaled_offset_q4_unsigned arg 'd': Incorrect shape (got {d_desc.t.shape})"
            )

        # dmin arg
        v_n, v_sup_count, one = dmin_desc.t.shape
        dmin_desc.specialize_all_dims()
        if v_n != n or v_sup_count != sup_count or one != 1:
            # Report the shape of `dmin` itself (not `d`) so the message
            # points at the offending argument.
            raise ValueError(
                f"mmt_super_block_scaled_offset_q4_unsigned arg 'dmin': Incorrect shape (got {dmin_desc.t.shape})"
            )

        # sb_scales_hi arg
        v_n, v_sup_count, v_sub_div4 = sb_scales_hi_desc.t.shape
        sb_scales_hi_desc.specialize_all_dims()
        if v_n != n or v_sup_count != sup_count or v_sub_div4 != (sub_count // 4):
            raise ValueError(
                f"mmt_super_block_scaled_offset_q4_unsigned arg 'sb_scales_hi': Incorrect shape (got {sb_scales_hi_desc.t.shape})"
            )

        # sb_scales_low arg
        v_n, v_sup_count, v_sub_div2 = sb_scales_low_desc.t.shape
        sb_scales_low_desc.specialize_all_dims()
        if v_n != n or v_sup_count != sup_count or v_sub_div2 != (sub_count // 2):
            raise ValueError(
                f"mmt_super_block_scaled_offset_q4_unsigned arg 'sb_scales_low': Incorrect shape (got {sb_scales_low_desc.t.shape})"
            )

        # sb_mins_hi arg
        v_n, v_sup_count, v_sub_div4 = sb_mins_hi_desc.t.shape
        sb_mins_hi_desc.specialize_all_dims()
        if v_n != n or v_sup_count != sup_count or v_sub_div4 != (sub_count // 4):
            raise ValueError(
                f"mmt_super_block_scaled_offset_q4_unsigned arg 'sb_mins_hi': Incorrect shape (got {sb_mins_hi_desc.t.shape})"
            )

        # sb_mins_low arg
        v_n, v_sup_count, v_sub_div2 = sb_mins_low_desc.t.shape
        sb_mins_low_desc.specialize_all_dims()
        if v_n != n or v_sup_count != sup_count or v_sub_div2 != (sub_count // 2):
            raise ValueError(
                f"mmt_super_block_scaled_offset_q4_unsigned arg 'sb_mins_low': Incorrect shape (got {sb_mins_low_desc.t.shape})"
            )

        # c return
        c = torch.empty(batch_dims + [m, n], dtype=a_desc.t.dtype)
        c_desc = ksel.return_tensor(c)  # Shape batch..., m, n
        c_desc.specialize_dims(-1)

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        """Instantiate the MLIR template for the selected shapes and call it.

        The template parameters are read from the types of `a` (K and element
        type), `d` (scale element type) and `qs` (N, SUP_COUNT, SUB_COUNT,
        BS // 2).
        """
        a = kb.arg_value(0)
        a_tensor_type = RankedTensorType(a.type)
        *_, k = a_tensor_type.shape
        d = kb.arg_value(1)
        d_tensor_type = RankedTensorType(d.type)
        qs = kb.arg_value(7)
        qs_tensor_type = RankedTensorType(qs.type)
        n, sup_count, sub_count, bs_div2 = qs_tensor_type.shape
        bs = bs_div2 * 2
        a_type_str = str(a_tensor_type.element_type)
        scale_type_str = str(d_tensor_type.element_type)

        template_file = "mmt_super_block_scaled_offset_q4_unsigned_3d.mlir"
        target_function_name = f"mmt_super_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{sup_count}_{sub_count}_{bs}_{a_type_str}"

        target_function = inline_template_function(
            kb,
            template_file,
            target_function_name,
            n=n,
            k=k,
            sup_count=sup_count,
            sub_count=sub_count,
            sub_div4=sub_count // 4,
            sub_div2=sub_count // 2,
            bs=bs,
            bs_div2=bs_div2,
            a_type=a_type_str,
            scale_type=scale_type_str,
        )
        # NOTE: the stray debug print of the generated module that used to
        # follow this call has been removed.
        kb.yield_results(*call_function(target_function, *kb.arg_bindings))
| 252 | + |
| 253 | + |
98 | 254 | @CustomOp.register(library=LIBRARY) |
99 | 255 | class mmt_block_scaled_q8(CustomOp): |
100 | 256 | """Generic block scaled matmul with transposed RHS. |
|
0 commit comments