
Commit 12e2fdd

[custom ops] Implement proper AOT dynamic shape support for custom ops. (#637)
After re-reading Edward's excellent dynamic shapes manual (https://bit.ly/3Q7Dc18), I realized some things weren't quite right. Those are corrected, and all of the LLM ops are now tested for eager and AOT export with dynamic shapes.

* Remove the assert in AOT custom ops for dynamic shapes not being implemented.
* Use FakeTensorMode and ShapeEnv to create FakeTensors instead of real ones for AOT codegen.
* Fix a bug where we were using signless types in AOT custom ops vs sign-carrying (showed up in a uint8 test case).
* Added a KernelSelection.return_new_tensor() API to return a new tensor that is properly symbolic.
* Broke all of the LLM kernels into individual files and tests.
* Rewrote all of the shape checking code in the LLM kernels to use torch._check (a short sketch of the idiom follows the file stats below). Aside from being more ergonomic, this matches internal PyTorch conventions and has super powers: it adds shape constraints to the solver, resulting in very detailed shape info.
* Added export test coverage for LLM ops.
1 parent e541aef commit 12e2fdd

14 files changed (+879, -564 lines)
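The torch._check convention mentioned above deserves a concrete illustration. The snippet below is a minimal sketch, not taken from this commit (the function and tensor names are hypothetical): each shape relationship is asserted with torch._check, which raises a descriptive error in eager mode and, under export, also records the constraint with the symbolic-shapes solver.

# Minimal torch._check sketch; names are illustrative, not from this diff.
import torch

def check_mmt_shapes(a: torch.Tensor, b: torch.Tensor):
    # Validate a batched "matmul with transposed RHS" contraction: a @ b.T.
    torch._check(a.dim() == 3, lambda: f"expected rank-3 lhs, got rank {a.dim()}")
    torch._check(b.dim() == 2, lambda: f"expected rank-2 rhs, got rank {b.dim()}")
    # Relating two dims also teaches the export-time shape solver that they are
    # equal, so downstream ops see the refined symbolic shapes.
    torch._check(
        a.shape[2] == b.shape[1],
        lambda: f"reduction dim mismatch: {a.shape[2]} vs {b.shape[1]}",
    )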

core/shark_turbine/runtime/op_reg/base.py

Lines changed: 8 additions & 0 deletions
@@ -350,6 +350,14 @@ def return_tensor(self, t: Tensor) -> "TensorArg":
         """
         ...
 
+    def return_new_tensor(self, size: list, dtype: torch.dtype) -> "TensorArg":
+        """Constructs a new symbolic tensor and marks the next result as returning it.
+
+        This delegates to `return_tensor` but takes care of some easy to mess
+        up boilerplate for dynamic shapes.
+        """
+        return self.return_tensor(torch.empty(size, dtype=dtype, device="meta"))
+
 
 class EagerKernelSelection(KernelSelection):
     """Kernel selection specialized for eager arguments."""

core/shark_turbine/transforms/general/custom_op_expansion.py

Lines changed: 56 additions & 9 deletions
@@ -4,8 +4,12 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from typing import Callable
+
 import torch
 from torch import Tensor
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
 
 from ...dynamo.type_conversion import (
     NativeTypeConverter,
@@ -52,6 +56,8 @@ def __init__(
         self.ops_to_delete: dict[Operation, None] = {}
         self.type_converter = NativeTypeConverter(root_op.context)
         self.symbol_table = SymbolTable(root_op)
+        self.shape_env = ShapeEnv()
+        self.fake_mode = FakeTensorMode(shape_env=self.shape_env)
 
     def delete_op(self, op):
         self.ops_to_delete[op.operation] = None
@@ -86,9 +92,15 @@ def expand_func(self, func_op: Operation):
     def expand_custom_op(self, op_reg: CustomOp, op: Operation):
         original_operands: list[Value] = list(op.operands)
         ksel = AOTKernelSelection(
-            op_reg, original_operands, list(op.results), self.type_converter
+            op_reg,
+            original_operands,
+            list(op.results),
+            self.type_converter,
+            self.shape_env,
         )
-        op_reg.select(ksel)
+        with self.fake_mode:
+            op_reg.select(ksel)
+        ksel._run_validators()
 
         module_body = self.root_op.regions[0].blocks[0]
         kb = InlineKernelBuilder(
@@ -110,6 +122,8 @@ class AOTKernelSelection(KernelSelection):
         "operands",
         "results",
         "type_converter",
+        "shape_env",
+        "_validators",
     ]
 
     def __init__(
@@ -118,11 +132,18 @@ def __init__(
         operands: list[Value],
         results: list[Value],
         type_converter: NativeTypeConverter,
+        shape_env: ShapeEnv,
     ):
         super().__init__(op, len(operands))
         self.operands = operands
         self.results = results
         self.type_converter = type_converter
+        self.shape_env = shape_env
+        self._validators: list[Callable] = []
+
+    def _run_validators(self):
+        for v in self._validators:
+            v()
 
     def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> TensorArg:
         # This is annoying: We have to go from the Torch MLIR type system to the
@@ -133,29 +154,55 @@ def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> TensorArg:
         arg_descs = self.arg_descs
         assert arg_descs[arg] is None, f"Already constrained argument {arg}"
         operand = self.operands[arg]
-        signed_native_type = self.type_converter.torch_type_to_native(operand.type)
+        signed_native_type = self.type_converter.torch_type_to_native(
+            operand.type, signless=False
+        )
         try:
             rtt = RankedTensorType(signed_native_type)
-            # TODO: We need to do the FakeMode/ShapeEnv dance to create a symbolic
-            # fake tensor here.
         except TypeError as e:
             raise TypeError(
                 f"Argument type mismatch from Torch IR for arg {arg}: Expected ranked tensor, got {signed_native_type}"
             ) from e
-        assert not any(
-            rtt.is_dynamic_dim(i) for i in range(rtt.rank)
-        ), "NYI: Dynamic shape tensors in custom op AOT mode"
         element_type_asm = str(rtt.element_type)
         try:
             dtype = MLIR_TYPE_ASM_TO_TORCH_DTYPE[element_type_asm]
         except KeyError as e:
             raise AssertionError(
                 f"Could not find dtype mapping for {element_type_asm} in MLIR_TYPE_ASM_TO_TORCH_DTYPE"
             )
-        t = torch.empty(rtt.shape, dtype=dtype, device="meta")
+
+        # Because we are operating in fake_mode, replace MLIR dyn dims with
+        # symints for the PyTorch type system.
+        shape_env = self.shape_env
+        sym_shape = [
+            d if d >= 0 else shape_env.create_unbacked_symint() for d in rtt.shape
+        ]
+        t = torch.empty(sym_shape, dtype=dtype)
         arg_descs[arg] = desc = TensorArg(t)
         if inplace_tied:
             self.inplace_tied_arg_descs.append(desc)
+
+        def validator():
+            rank = rtt.rank
+            for i in range(rank):
+                spec_dim = desc.spec_dims[i]
+                if rtt.is_dynamic_dim(i):
+                    # Make sure that it wasn't specialized.
+                    if spec_dim is not None:
+                        raise ValueError(
+                            f"Custom op {self.op}, arg {arg} requires a static dim "
+                            f"at index {i} but it is dynamic: {rtt}"
+                        )
+                else:
+                    # Make sure specialized dim matches.
+                    actual_dim = rtt.get_dim_size(i)
+                    if spec_dim is not None and actual_dim != spec_dim:
+                        raise ValueError(
+                            f"Custom op {self.op}, arg {arg} has a mismatched static "
+                            f"dim at index {i}: actual = {actual_dim}, expected = {spec_dim}"
+                        )
+
+        self._validators.append(validator)
         return desc
 
     def arg_tensor_list(self, arg: int) -> TensorListArg:
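For readers unfamiliar with the machinery wired up above, here is a small standalone sketch (not part of the commit) of the FakeTensorMode/ShapeEnv pattern that arg_tensor now uses: the ShapeEnv mints unbacked SymInts for dimensions that are dynamic in the MLIR type, and allocating under the FakeTensorMode yields fake tensors that carry those symbolic sizes instead of real storage.

# Standalone sketch of the FakeTensorMode/ShapeEnv dance used in arg_tensor above.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)

# Suppose the Torch IR type was tensor<?x?x3200xf32>: the two leading dims are
# dynamic (reported as negative sizes in RankedTensorType.shape), the last is static.
mlir_shape = [-1, -1, 3200]
with fake_mode:
    sym_shape = [
        d if d >= 0 else shape_env.create_unbacked_symint() for d in mlir_shape
    ]
    t = torch.empty(sym_shape, dtype=torch.float32)  # a FakeTensor, no real data
# t.shape now mixes unbacked SymInts with the static trailing dim, e.g. [u0, u1, 3200].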

llm/tests/ops/matmul_test.py

Lines changed: 0 additions & 103 deletions
This file was deleted.
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging

logging.basicConfig(level=logging.DEBUG)

import unittest

import torch

from shark_turbine import aot
from turbine_llm import ops
from turbine_llm.types import layout_utils


class mmt_block_scaled_offset_q4_unsigned_test(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

    def test_basic(self):
        a = torch.rand([4, 16, 3200], dtype=torch.float32)
        d = torch.rand([3200, 100, 1], dtype=torch.float16)
        qs = (torch.rand([3200, 100, 16], dtype=torch.float32) * 32).to(torch.uint8)
        m = torch.rand([3200, 100, 1], dtype=torch.float16)
        result = ops.mmt_block_scaled_offset_q4_unsigned(a, d, qs, m)

        # Dequantize and test with normal matmul.
        # Tolerances are empirical and results are not expected to match exactly.
        qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs)
        b = (d.to(torch.float32) * qs_i8.to(torch.float32) + m).flatten(1)
        torch.testing.assert_close(result, torch.matmul(a, b.T), atol=1e-1, rtol=1e-5)

    def testExportDynamicDims(self):
        class MyModule(torch.nn.Module):
            def forward(self, a, d, qs, m):
                return ops.mmt_block_scaled_offset_q4_unsigned(a, d, qs, m)

        mod = MyModule()
        batch = torch.export.Dim("batch")
        m = torch.export.Dim("m")
        ep = torch.export.export(
            mod,
            args=(
                torch.rand([4, 16, 3200], dtype=torch.float32),
                torch.rand([3200, 100, 1], dtype=torch.float16),
                (torch.rand([3200, 100, 16], dtype=torch.float32) * 32).to(torch.uint8),
                torch.rand([3200, 100, 1], dtype=torch.float16),
            ),
            dynamic_shapes={
                "a": {0: batch, 1: m},
                "d": {},
                "qs": {},
                "m": {},
            },
        )
        asm = str(aot.export(ep).mlir_module)
        self.assertIn(
            "@turbine_llm_mmt_block_scaled_offset_q4_unsigned_3d_3200_3200_32_f32", asm
        )


if __name__ == "__main__":
    unittest.main()
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging

logging.basicConfig(level=logging.DEBUG)

import unittest

import torch

from shark_turbine import aot
from turbine_llm import ops


class mmt_block_scaled_q8_test(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

    def testF32BS32(self):
        a = torch.rand([4, 16, 3200], dtype=torch.float32)
        d = torch.rand([3200, 100, 1], dtype=torch.float16)
        qs = (torch.rand([3200, 100, 32], dtype=torch.float32) * 32.0).to(torch.int8)
        result = ops.mmt_block_scaled_q8(a, d, qs)

        # Dequantize and test with normal matmul.
        # Tolerances are empirical and results are not expected to match exactly.
        b = (d.to(torch.float32) * qs.to(torch.float32)).flatten(1)
        torch.testing.assert_close(result, torch.matmul(a, b.T), atol=1e-1, rtol=1e-5)

    def testExportDynamicDims(self):
        class MyModule(torch.nn.Module):
            def forward(self, a, b, qs):
                return ops.mmt_block_scaled_q8(a, b, qs)

        mod = MyModule()
        batch = torch.export.Dim("batch")
        m = torch.export.Dim("m")
        ep = torch.export.export(
            mod,
            args=(
                torch.rand([4, 16, 3200], dtype=torch.float32),
                torch.rand([3200, 100, 1], dtype=torch.float16),
                (torch.rand([3200, 100, 32], dtype=torch.float32) * 32.0).to(
                    torch.int8
                ),
            ),
            dynamic_shapes={
                "a": {0: batch, 1: m},
                "b": {},
                "qs": {},
            },
        )
        asm = str(aot.export(ep).mlir_module)
        self.assertIn("@turbine_llm_mmt_block_scaled_q8_3d_3200_3200_32_f32", asm)


if __name__ == "__main__":
    unittest.main()
