Skip to content

Commit e19d5c6

Browse files
[llm] Implement mixed precision custom ops (#612)
Core changes: * JIT for CPU uses host features * Move llm env/config setup to core and activate via the `TURBINE_DEBUG=` env option (logging can be enabled via `TURBINE_DEBUG=log_level=debug`) * Add a runtime JIT tracing facility that dumps directories of MLIR and benchmark configs if `TURBINE_DEBUG=runtime_trace_dir=/dir` is set * Fixes the merger utility to handle `util.func`. LLM changes: * Large scale code reorganization to properly pull core types out of layer library * Add simple MLIR template setup for adding custom ops * Define custom ops: * `mmtfp`: Mixed precision matmul with a transposed-B on floating point values. * `mmt_block_scaled_q8`: Block scaled batch matmul with transposed-B operating on 8 bit quants/offsets (GGUF Q8_0 and Q8_1, although the latter is not fully wired). * `mmt_block_scaled_offset_q4_unsigned`: Block scaled batch matmul with transposed-B operating on 4 bit unsigned quants and offsets (GGUF Q4_1). * Basic numeric tests of the above compared against a PyTorch reference impl. * `QuantizedTensor` layouts for `Q8_0`, `Q4_K`, `Q4_1`, `Q6_K` (The `K` variants are currently just stubs so that we can inspect datasets that have them). * Generic planarized layout classes for `BlockScaledLayout` and `BlockScaledI4Layout` that the GGUF layouts unpack to when no other more specific implementation is present. * `InferenceOps` gets a custom subclass that dispatches to optimized kernels vs PyTorch reference implementations. * Added alias HF datasets for testing.
1 parent af54df5 commit e19d5c6

38 files changed

+1910
-325
lines changed

core/shark_turbine/runtime/device.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,10 @@ def _device_export_torch_tensor_cpu(
298298
}
299299

300300
DEVICE_TARGET_COMPILE_FLAGS: dict[str, tuple[str, ...]] = {
301-
"local-task": ("--iree-hal-target-backends=llvm-cpu",),
301+
"local-task": (
302+
"--iree-hal-target-backends=llvm-cpu",
303+
"--iree-llvmcpu-target-cpu-features=host",
304+
),
302305
}
303306

304307
# Aliases.

core/shark_turbine/runtime/op_reg/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from typing import Any, Callable, Optional, Sequence, Type, Union, cast
1212

13-
from abc import ABC, abstractmethod, abstractproperty
13+
from abc import ABC, abstractmethod
1414
import functools
1515
import logging
1616
import re
@@ -37,6 +37,8 @@
3737
func_d,
3838
)
3939

40+
from ...support.logging import runtime_logger as logger
41+
4042
from ...support.conversions import (
4143
TORCH_DTYPE_TO_IREE_TYPE_ASM,
4244
)
@@ -53,7 +55,6 @@
5355
"def_library",
5456
]
5557

56-
logger = logging.getLogger("turbine.runtime.op_reg")
5758

5859
###############################################################################
5960
# Op library management
@@ -167,7 +168,8 @@ def __init__(
167168
fq_name = f"{library.ns}.{name}"
168169
ALL_CUSTOM_OP_REGS[fq_name] = self
169170

170-
@abstractproperty
171+
@property
172+
@abstractmethod
171173
def signature(self) -> str:
172174
"""PyTorch function signature.
173175
@@ -616,6 +618,7 @@ def __init__(
616618
self.arg_bindings = arg_bindings
617619
self.ip = ip
618620
self.module_body = module_body
621+
self.context = module_body.owner.context
619622
self.symbol_table = symbol_table
620623
self.yielded = False
621624

core/shark_turbine/runtime/op_reg/compiler.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from dataclasses import dataclass
88
from timeit import default_timer
9-
from typing import Any
9+
from typing import Any, Optional
1010

1111
from iree.compiler.api import (
1212
Session,
@@ -36,6 +36,8 @@
3636
Device,
3737
)
3838

39+
from ..tracing import tracer
40+
3941
from .base import (
4042
FreeFuncKernelBuilder,
4143
KernelSelection,
@@ -69,6 +71,10 @@ class KernelCompileConfig:
6971
# things like unbacked memory mappings, etc.
7072
keep_alive: Any = None
7173

74+
# If tracing is enabled, this may contain a sanitized key that can be
75+
# used to log additional information against the kernel.
76+
tracing_key: Optional[str] = None
77+
7278

7379
# TODO: The cache should be more than just a simple dict. Can be persistent
7480
KERNEL_CACHE: dict[str, tuple[VmContext, VmFunction, KernelCompileConfig]] = {}
@@ -95,7 +101,7 @@ def compile_standalone_kernel(
95101
ksel.op.generate(ksel, kb)
96102
kb.module_op.verify()
97103
module_asm = kb.module_op.get_asm(
98-
binary=True, enable_debug_info=True, print_generic_op_form=True
104+
binary=True, enable_debug_info=True, print_generic_op_form=False
99105
)
100106
generation_time = default_timer() - start
101107

@@ -129,14 +135,23 @@ def compile_standalone_kernel(
129135
vm_context = VmContext(vm_instance, [device.create_hal_module(), vm_module])
130136
main_function = vm_module.lookup_function("main")
131137

132-
logger.debug(
133-
"Compiled kernel %s: mlir=%d bytes, vmfb=%d bytes (generation: %sms, compilation: %sms)",
134-
cache_key,
135-
len(module_asm),
136-
len(mapped_memory),
137-
generation_time * 1000,
138-
compilation_time * 1000,
139-
)
138+
if tracer.enabled:
139+
config.tracing_key = tracer.save_jit_kernel_artifacts(
140+
cache_key=cache_key, module_asm=module_asm, binary=mapped_memory
141+
)
142+
tracer.log_structured(
143+
tag="COMPILE",
144+
msg=f"Compiled kernel {config.tracing_key}, cache_key={cache_key}",
145+
columns=[
146+
config.tracing_key,
147+
main_function.name,
148+
len(module_asm),
149+
len(mapped_memory),
150+
generation_time * 1000,
151+
compilation_time * 1000,
152+
" ".join(session.get_flags(non_default_only=True)),
153+
],
154+
)
140155
cache_hit = (vm_context, main_function, config)
141156
KERNEL_CACHE[cache_key] = cache_hit
142157
return cache_hit

core/shark_turbine/runtime/op_reg/eager.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from iree.runtime import (
1515
HalBufferView,
16+
HalElementType,
17+
VmRef,
1618
VmVariantList,
1719
)
1820

@@ -29,6 +31,8 @@
2931
lookup_device_from_torch,
3032
)
3133

34+
from ..tracing import tracer
35+
3236
from .base import (
3337
AttrArg,
3438
IntArg,
@@ -37,6 +41,7 @@
3741

3842
from .compiler import (
3943
compile_standalone_kernel,
44+
KernelCompileConfig,
4045
)
4146

4247
__all__ = [
@@ -153,7 +158,8 @@ def push_tensor(tensor_arg):
153158
start = default_timer()
154159
vm_context.invoke(vm_f, arg_list, ret_list)
155160
invoke_time = default_timer() - start
156-
logger.debug("Kernel invocation %s: %sms", config.key, invoke_time * 1000)
161+
if tracer.enabled:
162+
_log_eager_dispatch(config, arg_list, invoke_time * 1000)
157163

158164
# Unpack results.
159165
results = []
@@ -179,3 +185,62 @@ def push_tensor(tensor_arg):
179185
return None
180186
else:
181187
return tuple(results)
188+
189+
190+
def _log_eager_dispatch(
    config: KernelCompileConfig, arg_list: VmVariantList, invoke_time_millis: float
):
    """Emit a structured ``INVOKE_KERNEL`` trace record for one eager dispatch.

    Args:
        config: Compile config of the invoked kernel; its ``tracing_key`` is
            logged as the first column.
        arg_list: The VM argument list that was passed to the invocation.
        invoke_time_millis: Duration of the invocation, logged as the second
            column.

    Formatting the arguments is best effort: if it fails, the failure is
    reported through the tracer and the record is still emitted with whatever
    arguments were formatted before the error.
    """
    args = []
    try:
        for i in range(arg_list.size):
            variant = arg_list.get_variant(i)
            # Buffer views get a compact "<shape>x<dtype>" rendering; every
            # other variant is logged as-is.
            if isinstance(variant, VmRef) and variant.isinstance(HalBufferView):
                args.append(_log_format_buffer_view(variant.deref(HalBufferView)))
            else:
                args.append(variant)
    except Exception:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt/SystemExit still propagate out of tracing code.
        tracer.exception("Exception while pretty-printing arguments")

    tracer.log_structured(
        tag="INVOKE_KERNEL",
        msg="",
        columns=[config.tracing_key, invoke_time_millis] + args,
    )
211+
212+
213+
def _log_format_buffer_view(bv: HalBufferView) -> str:
    """Render a buffer view as ``<dim>x<dim>x...x<dtype>`` for trace logs.

    Element types without an entry in ``_LOG_HAL_ELEMENT_TYPE_DESC`` are
    rendered as the element type value wrapped in angle brackets.
    """
    # TODO: We should expose this as a method on HalBufferView upstream instead
    # of half doing it here.
    dims = "x".join(map(str, bv.shape))
    element_type = bv.element_type
    type_name = _LOG_HAL_ELEMENT_TYPE_DESC.get(element_type)
    if type_name is None:
        type_name = f"<{element_type}>"
    return dims + "x" + type_name
221+
222+
223+
# Short textual names used when logging the element type of a buffer view.
_LOG_HAL_ELEMENT_TYPE_DESC = {
    # Floating point, boolean and complex.
    HalElementType.BFLOAT_16: "bf16",
    HalElementType.BOOL_8: "i1",
    HalElementType.COMPLEX_64: "cf64",
    HalElementType.COMPLEX_128: "cf128",
    HalElementType.FLOAT_16: "f16",
    HalElementType.FLOAT_32: "f32",
    HalElementType.FLOAT_64: "f64",
    # Signless integers.
    HalElementType.INT_4: "i4",
    HalElementType.INT_8: "i8",
    HalElementType.INT_16: "i16",
    HalElementType.INT_32: "i32",
    HalElementType.INT_64: "i64",
    # Signed integers.
    HalElementType.SINT_4: "si4",
    HalElementType.SINT_8: "si8",
    HalElementType.SINT_16: "si16",
    HalElementType.SINT_32: "si32",
    HalElementType.SINT_64: "si64",
    # Unsigned integers.
    HalElementType.UINT_4: "ui4",
    HalElementType.UINT_8: "ui8",
    HalElementType.UINT_16: "ui16",
    HalElementType.UINT_32: "ui32",
    HalElementType.UINT_64: "ui64",
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import hashlib
8+
import os
9+
from pathlib import Path
10+
import logging
11+
12+
from ..support.debugging import flags
13+
from ..support.logging import get_logger, DefaultFormatter
14+
15+
logger = get_logger("turbine.runtime")
16+
17+
18+
class RuntimeTracer:
    """Hook points for fine grained tracing of runtime interactions.

    This base implementation is disabled (``enabled`` is ``False``) and every
    hook is a no-op; subclasses override the hooks to record real output.
    """

    __slots__ = ["enabled"]

    def __init__(self):
        self.enabled: bool = False

    def save_jit_kernel_artifacts(
        self, *, cache_key: str, module_asm: bytes, binary: memoryview
    ) -> str:
        """Return the tracing key for a kernel; here, ``cache_key`` unchanged."""
        return cache_key

    def info(self, msg, *args, **kwargs):
        """Discard an info-level message."""

    def error(self, msg, *args, **kwargs):
        """Discard an error-level message."""

    def exception(self, msg, *args, **kwargs):
        """Discard an exception report."""

    def log_structured(self, *, tag: str, msg: str, columns: list):
        """Discard a structured record."""
45+
46+
47+
class DirectoryTracer(RuntimeTracer):
    """Tracer that records a log file and JIT kernel artifacts in a directory.

    Log records go to ``runtime.log`` inside the directory; each saved kernel
    produces a ``<tracing_key>.mlir`` and a ``<tracing_key>.vmfb`` file next
    to it.
    """

    __slots__ = [
        "dir",
        "logger",
    ]

    def __init__(self, dir: Path):
        """Enable tracing and attach a file handler writing into ``dir``.

        Args:
            dir: Directory that receives ``runtime.log`` and kernel artifacts.
        """
        self.dir = dir
        self.enabled = True
        # Configure a dedicated (non-propagating) logger so trace records only
        # land in the trace log file.
        trace_logger = self.logger = logging.getLogger("turbine.runtime.tracer")
        log_file = dir / "runtime.log"
        trace_logger.setLevel(logging.DEBUG)
        handler = logging.FileHandler(log_file)
        handler.setFormatter(DefaultFormatter())
        trace_logger.addHandler(handler)
        trace_logger.propagate = False
        logger.info("Set up turbine runtime tracing to %s", log_file)
        trace_logger.info("Started process %d", os.getpid())

    def save_jit_kernel_artifacts(
        self, *, cache_key: str, module_asm: bytes, binary: memoryview
    ) -> str:
        """Write the kernel's MLIR and compiled binary into the trace directory.

        Args:
            cache_key: Kernel cache key; hashed to derive the file stem.
            module_asm: Module contents written to ``<tracing_key>.mlir``.
            binary: Compiled contents written to ``<tracing_key>.vmfb``.

        Returns:
            The tracing key (hex SHA-1 of ``cache_key``). It is returned even
            if writing the files failed; the failure is only logged.
        """
        # SHA-1 is used purely to get a short filesystem-safe name.
        tracing_key = hashlib.sha1(
            cache_key.encode(), usedforsecurity=False
        ).hexdigest()
        try:
            with open(self.dir / f"{tracing_key}.mlir", "wb") as f:
                f.write(module_asm)
            with open(self.dir / f"{tracing_key}.vmfb", "wb") as f:
                f.write(binary)
        except IOError:
            self.logger.exception("Error saving artifact for %s", tracing_key)
        else:
            # Only report success when both files were actually written.
            self.logger.info("Saved artifacts for %s", tracing_key)
        return tracing_key

    def info(self, msg, *args, **kwargs):
        """Log an info-level message to the trace log."""
        self.logger.info(msg, *args, **kwargs)

    def error(self, msg, *args, **kwargs):
        """Log an error-level message to the trace log."""
        self.logger.error(msg, *args, **kwargs)

    def exception(self, msg, *args, **kwargs):
        """Log a message with the active exception's traceback to the trace log."""
        # Default to stacklevel=2, but let an explicit caller-supplied
        # stacklevel win instead of raising a duplicate-keyword TypeError.
        kwargs.setdefault("stacklevel", 2)
        self.logger.exception(msg, *args, **kwargs)

    def log_structured(self, *, tag: str, msg: str, columns: list):
        """Log ``msg`` followed by a ``::<tag>`` line of tab-separated columns."""
        columns_joined = "\t".join(str(c) for c in columns)
        self.logger.info("%s\n::%s\t%s", msg, tag, columns_joined)
95+
96+
97+
# Determine whether configured to do real tracing.
98+
def _setup_default_tracer() -> RuntimeTracer:
    """Build the process-wide tracer from the ``runtime_trace_dir`` flag.

    Returns a ``DirectoryTracer`` when the flag is set and the directory can
    be prepared; otherwise returns the no-op ``RuntimeTracer``.
    """
    configured_dir = flags.runtime_trace_dir
    if not configured_dir:
        return RuntimeTracer()

    try:
        trace_dir = Path(configured_dir)
        trace_dir.mkdir(parents=True, exist_ok=True)
        return DirectoryTracer(trace_dir)
    except IOError:
        # Tracing is optional: report the problem and fall back to the no-op.
        logger.exception("Error configuring runtime tracing to: %s", trace_dir)
        return RuntimeTracer()


tracer: RuntimeTracer = _setup_default_tracer()

core/shark_turbine/support/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
# Debugging must be loaded first as other low level things depend on it.
8+
from .debugging import *
79
from .exceptions import *

0 commit comments

Comments
 (0)