Commit e541aef

[llm] Add Q4_K quantization. (#628)
Parent: 9ca50de

File tree: 11 files changed, +735 / -14 lines


core/shark_turbine/runtime/op_reg/base.py

Lines changed: 17 additions & 2 deletions
@@ -8,7 +8,7 @@
 dispatcher.
 """
 
-from typing import Any, Callable, Optional, Sequence, Type, Union, cast
+from typing import Any, Callable, List, Optional, Sequence, Type, Union, cast
 
 from abc import ABC, abstractmethod
 import functools
@@ -478,6 +478,9 @@ def mlir_type_asm(self) -> str:
         return "i64"
 
 
+_NoneInt: Optional[int] = None
+
+
 class TensorArg:
     __slots__ = [
         "t",
@@ -491,13 +494,25 @@ class TensorArg:
     def __init__(self, t: Tensor):
         self.t = t
         # Any static dims that we are specializing. Defaults to all dynamic.
-        self.spec_dims: Sequence[Optional[int]] = len(t.shape) * [None]
+        self.spec_dims = len(t.shape) * [_NoneInt]
         # All descriptors have an attribute to indicate their value
         # as a tensor, and those that aren't are fixated to None.
         # This is to enable fast lookup in the hot path of determining
         # how to dispatch.
         self.maybe_tensor_value: Tensor = t
 
+    def specialize_all_dims(self):
+        """Marks all dimensions as specialized."""
+        self.spec_dims = list(self.t.shape)
+
+    def specialize_dims(self, *indices: int):
+        """Specializes individual dimensions.
+
+        `i` can have negative indexing.
+        """
+        for i in indices:
+            self.spec_dims[i] = self.t.size(i)
+
     def __repr__(self):
         return (
             f"TensorArg(shape={self.t.shape}, dtype={self.t.dtype}, "
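The new `specialize_dims` / `specialize_all_dims` helpers are what the Q4_K kernel's `select()` (further down, in `matmul.py`) uses to pin static dimensions while leaving the rest dynamic. A minimal, self-contained sketch of the behavior; the trimmed-down `TensorArg` below is copied from the hunk above for illustration, not imported from `shark_turbine`:

```python
from typing import List, Optional

import torch

_NoneInt: Optional[int] = None


class TensorArg:
    """Cut-down stand-in for shark_turbine's TensorArg, enough to show spec_dims."""

    def __init__(self, t: torch.Tensor):
        self.t = t
        # Any static dims that we are specializing. Defaults to all dynamic.
        self.spec_dims: List[Optional[int]] = len(t.shape) * [_NoneInt]

    def specialize_all_dims(self):
        """Marks all dimensions as specialized."""
        self.spec_dims = list(self.t.shape)

    def specialize_dims(self, *indices: int):
        """Specializes individual dimensions; negative indices are allowed."""
        for i in indices:
            self.spec_dims[i] = self.t.size(i)


arg = TensorArg(torch.rand([4, 16, 3200]))
print(arg.spec_dims)     # [None, None, None] -> fully dynamic
arg.specialize_dims(-1)  # pin only the last (K) dimension
print(arg.spec_dims)     # [None, None, 3200]
```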

llm/tests/ops/matmul_test.py

Lines changed: 38 additions & 0 deletions
@@ -34,6 +34,9 @@ def test3DF32(self):
 
 
 class mmt_block_scaled_q8_test(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(42)
+
     def testF32BS32(self):
         a = torch.rand([4, 16, 3200], dtype=torch.float32)
         d = torch.rand([3200, 100, 1], dtype=torch.float16)
@@ -47,6 +50,9 @@ def testF32BS32(self):
 
 
 class mmt_block_scaled_offset_q4_unsigned_test(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(42)
+
     def test_basic(self):
         a = torch.rand([4, 16, 3200], dtype=torch.float32)
         d = torch.rand([3200, 100, 1], dtype=torch.float16)
@@ -61,5 +67,37 @@ def test_basic(self):
         torch.testing.assert_close(result, torch.matmul(a, b.T), atol=1e-1, rtol=1e-5)
 
 
+class mmt_super_block_scaled_offset_q4_unsigned(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(42)
+
+    @unittest.skip(
+        "compiler bad tile selection:"
+        "https://github.com/openxla/iree/issues/17078#issuecomment-2062331207"
+    )
+    def test_basic(self):
+        # n = 2560, k = 5120, sup = 20, sub = 8, bs = 32
+        a = torch.rand([4, 16, 5120], dtype=torch.float32)
+        d = torch.rand([2560, 20, 1], dtype=torch.float16)
+        dmin = torch.rand([2560, 20, 1], dtype=torch.float16)
+        sb_scales_hi = (torch.rand([2560, 20, 2], dtype=torch.float32) * 127).to(
+            torch.uint8
+        )
+        sb_scales_low = (torch.rand([2560, 20, 4], dtype=torch.float32) * 127).to(
+            torch.uint8
+        )
+        sb_mins_hi = (torch.rand([2560, 20, 2], dtype=torch.float32) * 127).to(
+            torch.uint8
+        )
+        sb_mins_low = (torch.rand([2560, 20, 4], dtype=torch.float32) * 127).to(
+            torch.uint8
+        )
+        qs = (torch.rand([2560, 20, 8, 16], dtype=torch.float32) * 127).to(torch.uint8)
+        result = ops.mmt_super_block_scaled_offset_q4_unsigned(
+            a, d, dmin, sb_scales_hi, sb_scales_low, sb_mins_hi, sb_mins_low, qs
+        )
+        # TODO: Validate numerics once enabled and crash bug fixed.
+
+
 if __name__ == "__main__":
     unittest.main()
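For anyone decoding the magic numbers in the skipped super-block test: every operand shape follows from the parameters in its leading comment and the op's documented layout (`K == SUP_COUNT * SUB_COUNT * BS`). A small illustrative check, not part of the test suite:

```python
# Illustrative shape bookkeeping for the skipped test above (assumed, not asserted by the suite).
n, sup_count, sub_count, bs = 2560, 20, 8, 32
k = sup_count * sub_count * bs  # 5120, must equal a.shape[-1]
expected_shapes = {
    "a": (4, 16, k),
    "d": (n, sup_count, 1),
    "dmin": (n, sup_count, 1),
    "sb_scales_hi": (n, sup_count, sub_count // 4),   # packed high bits
    "sb_scales_low": (n, sup_count, sub_count // 2),  # packed low bits
    "sb_mins_hi": (n, sup_count, sub_count // 4),
    "sb_mins_low": (n, sup_count, sub_count // 2),
    "qs": (n, sup_count, sub_count, bs // 2),         # two 4-bit values per byte
}
assert k == 5120
```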

llm/tests/types/layout_utils_test.py

Lines changed: 29 additions & 0 deletions
@@ -64,6 +64,35 @@ def test_promote_i4_block_to_i8_signed(self):
             r0,
         )
 
+    def test_promote_i2_block_to_i8(self):
+        data = torch.tensor([[0xC1, 0xB2, 0xA3, 0x94, 0x85]], dtype=torch.uint8)
+        expected = torch.tensor(
+            # fmt: off
+            [[
+                1, 0, 0, 3, # 0xC1
+                2, 0, 3, 2, # 0xB2
+                3, 0, 2, 2, # 0xA3
+                0, 1, 1, 2, # 0x94
+                1, 1, 0, 2  # 0x85
+            ]],
+            dtype=torch.uint8,
+            # fmt: on
+        )
+        r0 = promote_linear_i2_block_to_i8(data)
+        torch.testing.assert_close(r0, expected)
+
+    def test_promote_i6_block_to_i8(self):
+        # High 2 bit values: 0, 3, 1, 3, 1, 3, 0, 3
+        high = torch.tensor([[0xDC, 0xCD]], dtype=torch.uint8)
+        # Low 4 bit values:
+        # '0xb', '0xc', '0x2', '0x3', '0x1', '0x1', '0x6', '0x7'
+        low = torch.tensor([[0xCB, 0x32, 0x11, 0x76]], dtype=torch.uint8)
+        r0 = promote_linear_i6_block_to_i8(high, low)
+        r_debug = repr(debug_map_tensor_as_hex_string(r0))
+        self.assertEqual(
+            r_debug, "[['0xb', '0x3c', '0x12', '0x33', '0x11', '0x31', '0x6', '0x37']]"
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
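The `test_promote_i2_block_to_i8` expectations pin down the packing order: each byte carries four 2-bit values, lowest bit pair first. As an illustration only (not the library's actual `promote_linear_i2_block_to_i8` implementation), a pure-PyTorch unpacking that reproduces the expected tensor could look like this:

```python
import torch


def promote_i2_block_to_i8_sketch(data: torch.Tensor) -> torch.Tensor:
    """Unpack four little-endian 2-bit values from every uint8 (illustrative sketch)."""
    shifts = torch.tensor([0, 2, 4, 6], device=data.device)
    expanded = (data.unsqueeze(-1) >> shifts) & 0x3  # [..., bytes, 4]
    return expanded.flatten(-2).to(torch.uint8)      # [..., bytes * 4]


data = torch.tensor([[0xC1, 0xB2, 0xA3, 0x94, 0x85]], dtype=torch.uint8)
print(promote_i2_block_to_i8_sketch(data))
# tensor([[1, 0, 0, 3, 2, 0, 3, 2, 3, 0, 2, 2, 0, 1, 1, 2, 1, 1, 0, 2]], dtype=torch.uint8)
```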

llm/turbine_llm/ops/custom_inference_ops.py

Lines changed: 24 additions & 3 deletions
@@ -18,13 +18,15 @@
     InferenceTensor,
     PrimitiveTensor,
     QuantizedTensor,
+    SuperBlockOffsetScaled_4_6_Layout,
     gguf_interop,
 )
 
 from .matmul import (
     mmtfp,
     mmt_block_scaled_offset_q4_unsigned,
     mmt_block_scaled_q8,
+    mmt_super_block_scaled_offset_q4_unsigned,
 )
 
 __all__ = [
@@ -59,7 +61,7 @@ def _matmul(
         return NotImplemented
 
     # Handle quantized tensor layout switched.
-    handler = _QMMT_DISPATCH.get(type(rhs))
+    handler = _QMMT_DISPATCH.get(rhs.layout_type)
     if handler is None:
         return NotImplemented
     return handler(lhs, rhs)
@@ -87,7 +89,26 @@ def _mmt_block_scaled_q4(lhs: torch.Tensor, rhs: QuantizedTensor[BlockScaledI4La
     )
 
 
+def _mmt_super_block_offset_scaled_4_6_q4(
+    lhs: torch.Tensor, rhs: QuantizedTensor[SuperBlockOffsetScaled_4_6_Layout]
+):
+    rhs_unpacked = rhs.unpack()
+    sb_scales_hi, sb_scales_low = rhs_unpacked.sb_scales_bit_packed
+    sb_mins_hi, sb_mins_low = rhs_unpacked.sb_mins_bit_packed
+    return mmt_super_block_scaled_offset_q4_unsigned(
+        lhs,
+        rhs_unpacked.d,
+        rhs_unpacked.dmin,
+        sb_scales_hi,
+        sb_scales_low,
+        sb_mins_hi,
+        sb_mins_low,
+        rhs_unpacked.qs_bit_packed,
+    )
+
+
 _QMMT_DISPATCH: dict[type, Callable] = {
-    gguf_interop.Q4_1: _mmt_block_scaled_q4,
-    gguf_interop.Q8_0: _mmt_block_scaled,
+    BlockScaledI4Layout: _mmt_block_scaled_q4,
+    BlockScaledLayout: _mmt_block_scaled,
+    SuperBlockOffsetScaled_4_6_Layout: _mmt_super_block_offset_scaled_4_6_q4,
 }
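The dispatch change above keys `_QMMT_DISPATCH` on a tensor's `layout_type` rather than its concrete class, so any storage format that unpacks to a given layout can share one handler. A toy, self-contained sketch of that pattern; every name below is a stand-in, not a turbine_llm API:

```python
from typing import Callable, Dict, Type

import torch


# Toy layouts: several on-disk formats could all report the same layout class
# and therefore share a single matmul handler.
class BlockScaledLayoutDemo: ...
class SuperBlockLayoutDemo: ...


class DemoQuantizedTensor:
    def __init__(self, layout_type: Type):
        self.layout_type = layout_type


def _mmt_block_scaled_demo(lhs, rhs):
    return f"block-scaled kernel, lhs shape {tuple(lhs.shape)}"


def _mmt_super_block_demo(lhs, rhs):
    return f"super-block kernel, lhs shape {tuple(lhs.shape)}"


_DEMO_DISPATCH: Dict[Type, Callable] = {
    BlockScaledLayoutDemo: _mmt_block_scaled_demo,
    SuperBlockLayoutDemo: _mmt_super_block_demo,
}


def demo_matmul(lhs: torch.Tensor, rhs: DemoQuantizedTensor):
    handler = _DEMO_DISPATCH.get(rhs.layout_type)  # key on layout, not tensor class
    if handler is None:
        return NotImplemented
    return handler(lhs, rhs)


print(demo_matmul(torch.rand(4, 16, 3200), DemoQuantizedTensor(SuperBlockLayoutDemo)))
```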

llm/turbine_llm/ops/matmul.py

Lines changed: 156 additions & 0 deletions
@@ -12,6 +12,7 @@
     "mmtfp",
     "mmt_block_scaled_offset_q4_unsigned",
     "mmt_block_scaled_q8",
+    "mmt_super_block_scaled_offset_q4_unsigned",
 ]
 
 
@@ -95,6 +96,161 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
         kb.yield_results(*call_function(target_function, *kb.arg_bindings))
 
 
+@CustomOp.register(library=LIBRARY)
+class mmt_super_block_scaled_offset_q4_unsigned(CustomOp):
+    """Super block scaled q4 matmul with transposed RHS.
+
+    Arguments:
+
+    * `a`: [B, M, K]
+    * `d`: [N, SUP_COUNT, 1]
+    * `dmin`: [N, SUP_COUNT, 1]
+    * `sb_scales_hi`: [N, SUP_COUNT, SUB_COUNT // 4]
+    * `sb_scales_low`: [N, SUP_COUNT, SUB_COUNT // 2]
+    * `sb_mins_hi`: [N, SUP_COUNT, SUB_COUNT // 4]
+    * `sb_mins_low`: [N, SUP_COUNT, SUB_COUNT // 2]
+    * `qs`: [N, SUP_COUNT, SUB_COUNT, BS // 2]
+
+    Where: `K == SUP_COUNT * SUB_COUNT * BS`
+
+    With the hi/lo halves combined into single values, the dequantization
+    formula is:
+
+    ```
+    d_scaled = (d * sb_scales).unsqueeze(-1)
+    dmin_scaled = (dmin * sb_mins).unsqueeze(-1)
+    return d_scaled * qs - dmin_scaled
+    ```
+    """
+
+    signature = (
+        "mmt_super_block_scaled_offset_q4_unsigned("
+        "Tensor a, Tensor d, Tensor dmin, "
+        "Tensor sb_scales_hi, Tensor sb_scales_low, "
+        "Tensor sb_mins_hi, Tensor sb_mins_low, "
+        "Tensor qs"
+        ") -> (Tensor)"
+    )
+
+    def select(self, ksel: KernelSelection):
+        a_desc = ksel.arg_tensor(0)
+        d_desc = ksel.arg_tensor(1)
+        dmin_desc = ksel.arg_tensor(2)
+        sb_scales_hi_desc = ksel.arg_tensor(3)
+        sb_scales_low_desc = ksel.arg_tensor(4)
+        sb_mins_hi_desc = ksel.arg_tensor(5)
+        sb_mins_low_desc = ksel.arg_tensor(6)
+        qs_desc = ksel.arg_tensor(7)
+
+        # a arg
+        *batch_dims, m, k = a_desc.t.shape
+        a_desc.specialize_dims(-1)
+        if not a_desc.t.dtype.is_floating_point:
+            raise ValueError(
+                f"mmt_super_block_scaled_offset_q4_unsigned arg 'a': Expected floating point (got {a_desc.t.dtype})"
+            )
+        if len(batch_dims) != 1:
+            raise ValueError(
+                f"mmt_super_block_scaled_offset_q4_unsigned arg 'a': Expected 3d tensor (got {a_desc.t.shape})"
+            )
+
+        # qs arg
+        n, sup_count, sub_count, bs_div2 = qs_desc.t.shape
+        qs_desc.specialize_all_dims()
+        bs = bs_div2 * 2
+        if k != (sup_count * sub_count * bs):
+            raise ValueError(
+                f"mmt_super_block_scaled_offset_q4_unsigned arg 'qs': Incorrect shape (got {qs_desc.t.shape}, k={k})"
+            )
+
+        # d arg
+        v_n, v_sup_count, one = d_desc.t.shape
+        d_desc.specialize_all_dims()
+        if v_n != n or v_sup_count != sup_count or one != 1:
+            raise ValueError(
+                f"mmt_super_block_scaled_offset_q4_unsigned arg 'd': Incorrect shape (got {d_desc.t.shape})"
+            )
+
+        # dmin arg
+        v_n, v_sup_count, one = dmin_desc.t.shape
+        dmin_desc.specialize_all_dims()
+        if v_n != n or v_sup_count != sup_count or one != 1:
+            raise ValueError(
+                f"mmt_super_block_scaled_offset_q4_unsigned arg 'dmin': Incorrect shape (got {dmin_desc.t.shape})"
+            )
+
+        # sb_scales_hi arg
+        v_n, v_sup_count, v_sub_div4 = sb_scales_hi_desc.t.shape
+        sb_scales_hi_desc.specialize_all_dims()
+        if v_n != n or v_sup_count != sup_count or v_sub_div4 != (sub_count // 4):
+            raise ValueError(
+                f"mmt_super_block_scaled_offset_q4_unsigned arg 'sb_scales_hi': Incorrect shape (got {sb_scales_hi_desc.t.shape})"
+            )
+
+        # sb_scales_low arg
+        v_n, v_sup_count, v_sub_div2 = sb_scales_low_desc.t.shape
+        sb_scales_low_desc.specialize_all_dims()
+        if v_n != n or v_sup_count != sup_count or v_sub_div2 != (sub_count // 2):
+            raise ValueError(
+                f"mmt_super_block_scaled_offset_q4_unsigned arg 'sb_scales_low': Incorrect shape (got {sb_scales_low_desc.t.shape})"
+            )
+
+        # sb_mins_hi arg
+        v_n, v_sup_count, v_sub_div4 = sb_mins_hi_desc.t.shape
+        sb_mins_hi_desc.specialize_all_dims()
+        if v_n != n or v_sup_count != sup_count or v_sub_div4 != (sub_count // 4):
+            raise ValueError(
+                f"mmt_super_block_scaled_offset_q4_unsigned arg 'sb_mins_hi': Incorrect shape (got {sb_mins_hi_desc.t.shape})"
+            )
+
+        # sb_mins_low arg
+        v_n, v_sup_count, v_sub_div2 = sb_mins_low_desc.t.shape
+        sb_mins_low_desc.specialize_all_dims()
+        if v_n != n or v_sup_count != sup_count or v_sub_div2 != (sub_count // 2):
+            raise ValueError(
+                f"mmt_super_block_scaled_offset_q4_unsigned arg 'sb_mins_low': Incorrect shape (got {sb_mins_low_desc.t.shape})"
+            )
+
+        # c return
+        c = torch.empty(batch_dims + [m, n], dtype=a_desc.t.dtype)
+        c_desc = ksel.return_tensor(c)  # Shape batch..., m, n
+        c_desc.specialize_dims(-1)
+
+    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+        a = kb.arg_value(0)
+        a_tensor_type = RankedTensorType(a.type)
+        *_, k = a_tensor_type.shape
+        d = kb.arg_value(1)
+        d_tensor_type = RankedTensorType(d.type)
+        qs = kb.arg_value(7)
+        qs_tensor_type = RankedTensorType(qs.type)
+        n, sup_count, sub_count, bs_div2 = qs_tensor_type.shape
+        bs = bs_div2 * 2
+        a_type_str = str(a_tensor_type.element_type)
+        scale_type_str = str(d_tensor_type.element_type)
+
+        template_file = "mmt_super_block_scaled_offset_q4_unsigned_3d.mlir"
+        target_function_name = f"mmt_super_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{sup_count}_{sub_count}_{bs}_{a_type_str}"
+
+        target_function = inline_template_function(
+            kb,
+            template_file,
+            target_function_name,
+            n=n,
+            k=k,
+            sup_count=sup_count,
+            sub_count=sub_count,
+            sub_div4=sub_count // 4,
+            sub_div2=sub_count // 2,
+            bs=bs,
+            bs_div2=bs_div2,
+            a_type=a_type_str,
+            scale_type=scale_type_str,
+        )
+        kb.yield_results(*call_function(target_function, *kb.arg_bindings))
+        print(kb.module_body.owner)
+
+
 @CustomOp.register(library=LIBRARY)
 class mmt_block_scaled_q8(CustomOp):
     """Generic block scaled matmul with transposed RHS.

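The dequantization formula in the new op's docstring can be exercised in plain PyTorch once the bit-packed scales, mins, and q4 values have been promoted to integers (the real code gets these from the layout's `unpack()`; the promotion itself is omitted here). A hedged reference sketch with illustrative names, shapes taken from the docstring:

```python
import torch


def dequant_super_block_sketch(d, dmin, sb_scales, sb_mins, qs):
    """Reference for the formula in the op docstring.

    Assumed (already unpacked) shapes:
      d, dmin:            [N, SUP_COUNT, 1]
      sb_scales, sb_mins: [N, SUP_COUNT, SUB_COUNT]
      qs:                 [N, SUP_COUNT, SUB_COUNT, BS]
    """
    d_scaled = (d * sb_scales).unsqueeze(-1)      # [N, SUP_COUNT, SUB_COUNT, 1]
    dmin_scaled = (dmin * sb_mins).unsqueeze(-1)  # [N, SUP_COUNT, SUB_COUNT, 1]
    return d_scaled * qs - dmin_scaled            # [N, SUP_COUNT, SUB_COUNT, BS]


# Tiny end-to-end check of the mmt contract: K == SUP_COUNT * SUB_COUNT * BS
# and the RHS is transposed.
n, sup_count, sub_count, bs = 2560, 20, 8, 32
d = torch.rand(n, sup_count, 1, dtype=torch.float16)
dmin = torch.rand(n, sup_count, 1, dtype=torch.float16)
sb_scales = torch.randint(0, 64, (n, sup_count, sub_count))   # 6-bit values
sb_mins = torch.randint(0, 64, (n, sup_count, sub_count))
qs = torch.randint(0, 16, (n, sup_count, sub_count, bs))      # unsigned 4-bit values
weights = dequant_super_block_sketch(d, dmin, sb_scales, sb_mins, qs)
a = torch.rand(4, 16, sup_count * sub_count * bs, dtype=torch.float32)
ref = torch.matmul(a, weights.reshape(n, -1).t().to(torch.float32))  # [4, 16, N]
```

Whether this matches the kernel bit-for-bit depends on the exact unpacking and flattening order defined by `SuperBlockOffsetScaled_4_6_Layout` and the MLIR template, not on this sketch.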