Skip to content

Commit a176f99

Browse files
Start adding the library of custom ops. (#296)
Defines a couple of IREE builtins: * `ops.iree.trace_tensor` * `ops.iree.trace_tensors` Extends the infra for better support: * Adds support for `Tensor[]` arguments to custom ops.
1 parent 6f67a97 commit a176f99

File tree

18 files changed

+1018
-110
lines changed

18 files changed

+1018
-110
lines changed

python/shark_turbine/aot/support/ir_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,10 @@ def create_tensor_global(
246246
array = np.array(detached_tensor)
247247
# We know that a Numpy array is a ReadableBuffer so ignore type error.
248248
contents = memoryview(array) # type: ignore
249+
shape_desc = "_".join([str(d) for d in t.shape])
250+
blob_name = f"torch_tensor_{shape_desc}_{str(t.dtype)}"
249251
elements_attr = DenseResourceElementsAttr.get_from_buffer(
250-
contents, "from_py", tensor_type
252+
contents, blob_name, tensor_type
251253
)
252254
ir_attrs["initial_value"] = elements_attr
253255

python/shark_turbine/dynamo/type_conversion.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self, context: Context):
4646
self.torch_type_to_native
4747
)
4848

49-
def torch_type_to_native(self, torch_type: IrType) -> IrType:
49+
def torch_type_to_native(self, torch_type: IrType, signless: bool = True) -> IrType:
5050
"""Converts a presumed torch type to a corresponding native type.
5151
5252
This mirrors the type conversion in torch-mlir's BackendTypeConversion.cpp.
@@ -56,6 +56,8 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
5656
!torch.float -> f64
5757
!torch.bool -> i1
5858
!torch.vtensor -> tensor
59+
60+
If `signless=False`, then integer types will retain their signs.
5961
"""
6062
# We don't presently have API support for introspecting torch type,
6163
# and even if we did, it is likely that this is more efficient.
@@ -66,7 +68,11 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
6668
if name == "bool":
6769
return IntegerType.get_signless(1)
6870
if name == "int":
69-
return IntegerType.get_signless(64)
71+
return (
72+
IntegerType.get_signless(64)
73+
if signless
74+
else IntegerType.get_signed(64)
75+
)
7076
elif name == "float":
7177
return F64Type.get()
7278
elif name == "vtensor":
@@ -75,22 +81,25 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
7581
dim_list_str, dtype_str = tm.groups()
7682
dim_list = parse_tensor_dim_list(dim_list_str)
7783
dtype = self.convert_torch_element_type_to_native(
78-
IrType.parse(dtype_str)
84+
IrType.parse(dtype_str), signless=signless
7985
)
8086
# TODO: Eliminate RankedTensorType dependence on Location.
8187
# See: https://github.com/nod-ai/SHARK-Turbine/issues/145
8288
with Location.unknown():
8389
return RankedTensorType.get(dim_list, dtype)
8490
raise TypeError(f"Unsupported torch type conversion for {torch_type}")
8591

86-
def convert_torch_element_type_to_native(self, torch_type: IrType) -> IrType:
92+
def convert_torch_element_type_to_native(
93+
self, torch_type: IrType, signless: bool = True
94+
) -> IrType:
8795
# Torch uses the builtin type hierarchy of IntegerType and FloatType
8896
# to represent dtypes. These are mostly the same, but it always uses
8997
# signed IntegerTypes which we must convert to signless for the native
9098
# type system.
91-
if IntegerType.isinstance(torch_type):
92-
signed_int_type = IntegerType(torch_type)
93-
return IntegerType.get_signless(signed_int_type.width)
99+
if signless:
100+
if IntegerType.isinstance(torch_type):
101+
signed_int_type = IntegerType(torch_type)
102+
return IntegerType.get_signless(signed_int_type.width)
94103
return torch_type
95104

96105
def materialize_native_to_torch(
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright 2023 Advanced Micro Devices, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
from . import iree

python/shark_turbine/ops/iree.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2023 Advanced Micro Devices, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
"""Custom ops for built-in IREE functionality."""
8+
9+
from ..support.ir_imports import (
10+
RankedTensorType,
11+
StringAttr,
12+
Value,
13+
flow_d,
14+
tensor_d,
15+
)
16+
17+
from ..runtime.op_reg import (
18+
CustomOp,
19+
KernelBuilder,
20+
KernelSelection,
21+
def_library,
22+
)
23+
24+
__all__ = [
25+
"trace",
26+
]
27+
28+
IREE_LIBRARY = def_library("iree")
29+
30+
31+
################################################################################
32+
# trace_tensor / trace_tensors
33+
# See the flow.tensor_trace op for details. In essence:
34+
# * trace_key is a name to label tensors with (intended for log filtering)
35+
# * tensor or tensors are values to log a value for
36+
################################################################################
37+
38+
39+
def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]):
40+
dynamic_dims = []
41+
for t in ts:
42+
rtt = RankedTensorType(t.type)
43+
for i in range(rtt.rank):
44+
if rtt.is_dynamic_dim(i):
45+
dynamic_dims.append(tensor_d.dim(t, kb.constant_index(i)))
46+
flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims)
47+
48+
49+
@CustomOp.register(library=IREE_LIBRARY)
50+
class trace_tensor(CustomOp):
51+
signature = "trace_tensor(str trace_key, Tensor tensor) -> ()"
52+
53+
def select(self, ksel: KernelSelection):
54+
ksel.attr_str(0)
55+
ksel.arg_tensor(1)
56+
57+
def generate(self, ksel: KernelSelection, kb: KernelBuilder):
58+
_emit_tensor_trace(kb, ksel.arg_descs[0].v, [kb.arg_bindings[1]])
59+
kb.yield_results()
60+
61+
62+
@CustomOp.register(library=IREE_LIBRARY)
63+
class trace_tensors(CustomOp):
64+
signature = "trace_tensors(str trace_key, Tensor[] tensors) -> ()"
65+
66+
def select(self, ksel: KernelSelection):
67+
ksel.attr_str(0)
68+
ksel.arg_tensor_list(1)
69+
70+
def generate(self, ksel: KernelSelection, kb: KernelBuilder):
71+
ts = kb.arg_bindings[1]
72+
if len(ts) >= 1:
73+
_emit_tensor_trace(kb, ksel.arg_descs[0].v, ts)
74+
kb.yield_results()

python/shark_turbine/runtime/device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBuffe
263263
memory_type=MemoryType.DEVICE_LOCAL,
264264
allowed_usage=BufferUsage.DEFAULT,
265265
device=hal_device,
266-
buffer=t.numpy(),
266+
buffer=t.detach().numpy(),
267267
element_type=element_type,
268268
)
269269
return bv

0 commit comments

Comments
 (0)