Skip to content

Commit 36ec590

Browse files
Make tensor updates to a global work. (#64)
The global and immediate Ir{Scalar|Tensor} hierarchy was unfortunately divergent. This unifies it and makes the test pass. --------- Co-authored-by: Stella Laurenzo <[email protected]>
1 parent d51df8d commit 36ec590

File tree

6 files changed

+240
-91
lines changed

6 files changed

+240
-91
lines changed

python/shark_turbine/aot/builtins/jittable.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@
4141

4242
from ..support.procedural import (
4343
CallableIntrinsic,
44+
IrImmediateTensor,
45+
IrTensor,
4446
IrTrace,
45-
IrValueTensor,
4647
MaterializedGlobal,
4748
)
4849

@@ -98,7 +99,7 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
9899
# Emit a global load and conversion.
99100
vtensor_type = gni._cc.tensor_to_vtensor_type(py_value)
100101
loaded_value = util_d.GlobalLoadOp(
101-
materialized_global.global_type, materialized_global.symbol_name
102+
materialized_global.ir_type, materialized_global.symbol_name
102103
).result
103104
converted_value = Operation.create(
104105
"torch_c.from_builtin_tensor",
@@ -249,7 +250,7 @@ def resolve_call(self, proc_trace: IrTrace, *py_args, **py_kwargs):
249250
flat_py_results = []
250251
for ir_result, pytorch_meta in zip(flat_ir_results, pytorch_meta_results):
251252
if isinstance(pytorch_meta, TensorMetadata):
252-
flat_py_results.append(IrValueTensor(ir_result, pytorch_meta.dtype))
253+
flat_py_results.append(IrImmediateTensor(ir_result, pytorch_meta.dtype))
253254
else:
254255
raise TypeError(
255256
f"Unknown PyTorch->IREE value mapping for jittable result: {pytorch_meta}->{ir_result}"
@@ -264,7 +265,7 @@ def resolve_call(self, proc_trace: IrTrace, *py_args, **py_kwargs):
264265
return flat_py_results
265266

266267
def _split_py_arg(self, arg) -> Tuple[Value, Any]:
267-
if isinstance(arg, IrValueTensor):
268+
if isinstance(arg, IrTensor):
268269
return arg.ir_value, arg._to_meta_tensor()
269270

270271
raise TypeError(f"Unsupported argument to jittable: {arg}")

python/shark_turbine/aot/support/procedural/base.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,17 @@ def resolve_assignment(self, proc_trace: "IrTrace", ir_values: Sequence[Value]):
112112
f"Cannot use {self} as the target of an assignment in a procedural function"
113113
)
114114

115+
# Helpers for accessing the ir_value within the current trace.
116+
@property
117+
def ir_values(self) -> Sequence[Value]:
118+
return self.resolve_ir_values(current_ir_trace())
119+
120+
@property
121+
def ir_value(self) -> Value:
122+
values = self.ir_values
123+
assert len(values) == 1, "Expected arity one intrinsic"
124+
return values[0]
125+
115126

116127
class CallableIntrinsic(Intrinsic):
117128
"""Intrinsic subclass that supports calls.
@@ -182,7 +193,7 @@ def __repr__(self):
182193
return f"AbstractTensor({', '.join(str(s) for s in self.size)}, dtype={self.dtype})"
183194

184195
def create_intrinsic(self, ir_value: Value) -> Intrinsic:
185-
return IrValueTensor(ir_value, self.dtype)
196+
return IrImmediateTensor(ir_value, self.dtype)
186197

187198
def get_ir_type(self, builder: ModuleBuilder) -> IrType:
188199
element_type = builder.torch_dtype_to_iree_type(self.dtype)
@@ -213,7 +224,7 @@ def __repr__(self):
213224
return f"AbstractScalar({self.label})"
214225

215226
def create_intrinsic(self, ir_value: Value) -> Intrinsic:
216-
return IrValueScalar(ir_value)
227+
return IrImmediateScalar(ir_value)
217228

218229
def get_ir_type(self, builder: ModuleBuilder) -> IrType:
219230
with builder.context:
@@ -249,6 +260,6 @@ def abstractify(tree):
249260

250261
# Circular imports.
251262
from .primitives import (
252-
IrValueScalar,
253-
IrValueTensor,
263+
IrImmediateScalar,
264+
IrImmediateTensor,
254265
)

python/shark_turbine/aot/support/procedural/globals.py

Lines changed: 93 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,18 @@
3535
)
3636

3737
from .base import (
38-
AbstractTypedef,
38+
AbstractScalar,
39+
AbstractTensor,
3940
Intrinsic,
4041
IrTrace,
4142
current_ir_trace,
4243
)
4344

45+
from .primitives import (
46+
IrScalar,
47+
IrTensor,
48+
)
49+
4450
###############################################################################
4551
# Globals
4652
###############################################################################
@@ -121,17 +127,18 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
121127
initialize=self._initialize,
122128
mutable=self._mutable,
123129
)
124-
mapping.value = MaterializedGlobal(
130+
mapping.value = IrGlobalTensor(
125131
fq_name,
126132
self,
127133
symbol_name=actual_symbol_name,
128134
global_op=global_op,
129135
global_type=global_type,
136+
dtype=value.dtype,
130137
)
131138
logger.debug("TRACK NEW TENSOR(%s): %r", fq_name, mapping)
132139
flat_globals.append(mapping.value)
133140
continue
134-
elif isinstance(value, AbstractTypedef):
141+
elif isinstance(value, AbstractTensor):
135142
global_type = value.get_ir_type(module_builder)
136143
(actual_symbol_name, global_op,) = module_builder.create_typed_global(
137144
f"_{fq_name}",
@@ -140,7 +147,26 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
140147
mutable=self._mutable,
141148
)
142149
flat_globals.append(
143-
MaterializedGlobal(
150+
IrGlobalTensor(
151+
fq_name,
152+
self,
153+
symbol_name=actual_symbol_name,
154+
global_op=global_op,
155+
global_type=global_type,
156+
dtype=value.dtype,
157+
)
158+
)
159+
continue
160+
elif isinstance(value, AbstractScalar):
161+
global_type = value.get_ir_type(module_builder)
162+
(actual_symbol_name, global_op,) = module_builder.create_typed_global(
163+
f"_{fq_name}",
164+
global_type,
165+
initialize=self._initialize,
166+
mutable=self._mutable,
167+
)
168+
flat_globals.append(
169+
IrGlobalScalar(
144170
fq_name,
145171
self,
146172
symbol_name=actual_symbol_name,
@@ -158,8 +184,14 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
158184
return LiveGlobalCollectionProxy(tree_globals)
159185

160186

161-
class MaterializedGlobal(Intrinsic):
162-
"""Associates a (possibly) materialized global with a name hint and info for the aggregate it is part of."""
187+
class MaterializedGlobal:
188+
"""Tags an Ir* that is duck-typed as a global."""
189+
190+
...
191+
192+
193+
class IrGlobalScalar(IrScalar, MaterializedGlobal):
194+
"""An IrScalar that is loaded from a global and associated with its aggregate."""
163195

164196
__slots__ = [
165197
"global_op",
@@ -178,15 +210,65 @@ def __init__(
178210
global_op: Operation,
179211
global_type: IrType,
180212
):
213+
super().__init__(global_type)
214+
self.info = info
215+
self.export_name = export_name
216+
self.symbol_name = symbol_name
217+
self.global_op = global_op
218+
219+
def resolve_ir_values(self, trace: IrTrace) -> Sequence[Value]:
220+
with trace.loc, trace.ip:
221+
value = util_d.GlobalLoadOp(self.ir_type, self.symbol_name).result
222+
return [value]
223+
224+
def resolve_assignment(self, proc_trace: "IrTrace", ir_values: Sequence[Value]):
225+
if len(ir_values) != 1:
226+
raise ValueError(
227+
f"Can only assign a single value to a global. Got {len(ir_values)}"
228+
)
229+
source_ir_type = ir_values[0].type
230+
if source_ir_type != self.ir_type:
231+
raise TypeError(
232+
f"Cannot assign to a global with a different type: {self.ir_type} != {source_ir_type}"
233+
)
234+
with proc_trace.loc, proc_trace.ip:
235+
util_d.GlobalStoreOp(ir_values[0], self.symbol_name)
236+
237+
def __repr__(self):
238+
return (
239+
f"<IrGlobalScalar {self.export_name} = {self.symbol_name}:{self.ir_type}>"
240+
)
241+
242+
243+
class IrGlobalTensor(IrTensor, MaterializedGlobal):
244+
"""An IrScalar that is loaded from a global and associated with its aggregate."""
245+
246+
__slots__ = [
247+
"global_op",
248+
"info",
249+
"export_name",
250+
"symbol_name",
251+
]
252+
253+
def __init__(
254+
self,
255+
export_name: str,
256+
info: GlobalsDef,
257+
*,
258+
symbol_name: str,
259+
global_op: Operation,
260+
global_type: IrType,
261+
dtype: torch.dtype,
262+
):
263+
super().__init__(global_type, dtype)
181264
self.info = info
182265
self.export_name = export_name
183266
self.symbol_name = symbol_name
184267
self.global_op = global_op
185-
self.global_type = global_type
186268

187269
def resolve_ir_values(self, trace: IrTrace) -> Sequence[Value]:
188270
with trace.loc, trace.ip:
189-
value = util_d.GlobalLoadOp(self.global_type, self.symbol_name).result
271+
value = util_d.GlobalLoadOp(self.ir_type, self.symbol_name).result
190272
return [value]
191273

192274
def resolve_assignment(self, proc_trace: "IrTrace", ir_values: Sequence[Value]):
@@ -195,12 +277,12 @@ def resolve_assignment(self, proc_trace: "IrTrace", ir_values: Sequence[Value]):
195277
f"Can only assign a single value to a global. Got {len(ir_values)}"
196278
)
197279
source_ir_type = ir_values[0].type
198-
if source_ir_type != self.global_type:
280+
if source_ir_type != self.ir_type:
199281
raise TypeError(
200-
f"Cannot assign to a global with a different type: {self.global_type} != {source_ir_type}"
282+
f"Cannot assign to a global with a different type: {self.ir_type} != {source_ir_type}"
201283
)
202284
with proc_trace.loc, proc_trace.ip:
203285
util_d.GlobalStoreOp(ir_values[0], self.symbol_name)
204286

205287
def __repr__(self):
206-
return f"<MaterializedGlobal {self.export_name} = {self.symbol_name}:{self.global_type}>"
288+
return f"<MaterializedGlobal {self.export_name} = {self.symbol_name}:{self.ir_type}>"

python/shark_turbine/aot/support/procedural/iree_emitter.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,20 @@
2727

2828
from .base import (
2929
Intrinsic,
30-
IrValueTensor,
31-
IrValueScalar,
3230
current_ir_trace,
3331
ShapedTypeDynamicSizeSentinel,
3432
)
3533

36-
BuildableScalarValue = Union[IrValueScalar, Value]
34+
from .primitives import (
35+
IrScalar,
36+
IrImmediateScalar,
37+
IrTensor,
38+
IrImmediateTensor,
39+
)
40+
41+
BuildableScalarValue = Union[IrScalar, Value]
3742
BuildableTensorDimDecl = Union[int, Value]
38-
BuildableTensorType = IrValueTensor
43+
BuildableTensorType = IrTensor
3944
BuildableIndexType = Union[BuildableScalarValue, int]
4045
BuildableIndexLengthType = Union[
4146
BuildableTensorDimDecl, Tuple[BuildableTensorDimDecl, BuildableTensorDimDecl]
@@ -51,8 +56,8 @@ def cast_scalar_value(x: BuildableScalarValue) -> Value:
5156
return x
5257

5358

54-
def cast_tensor_value(x: BuildableTensorType) -> IrValueTensor:
55-
assert isinstance(x, IrValueTensor), f"Expected a tensor but got {type(x)}"
59+
def cast_tensor_value(x: BuildableTensorType) -> IrTensor:
60+
assert isinstance(x, IrTensor), f"Expected a tensor but got {type(x)}"
5661
return x
5762

5863

@@ -126,16 +131,16 @@ def wrapper(*args, **kwargs):
126131

127132
class IREEEmitter:
128133
@emitter
129-
def tensor_dim(self, source: BuildableTensorType, index: int) -> "IrValueScalar":
134+
def tensor_dim(self, source: BuildableTensorType, index: int) -> "IrScalar":
130135
"""Gets the dimension size of a tensor at a static position."""
131136
source = cast_tensor_value(source)
132137
index = cast_static_bounded_index(index, 0, source.rank - 1)
133-
return IrValueScalar(source.get_dim_value(index))
138+
return IrImmediateScalar(source.get_dim_value(index))
134139

135140
@emitter
136141
def tensor_empty(
137142
self, *dims: BuildableTensorDimDecl, dtype: torch.dtype = torch.float32
138-
) -> IrValueTensor:
143+
) -> IrTensor:
139144
"""Constructs a tensor with uninitialized values.
140145
141146
TODO: Support an IREE/raw element type in addition to the torch dtype.
@@ -147,14 +152,14 @@ def tensor_empty(
147152
raise ValueError(f"Could not map Torch dtype {dtype} to an IREE type")
148153
tensor_type = RankedTensorType.get(dim_decls, element_type)
149154
raw_tensor = flow_d.TensorEmptyOp(tensor_type, dyn_dim_values).result
150-
result = IrValueTensor(raw_tensor, dtype=dtype)
155+
result = IrImmediateTensor(raw_tensor, dtype=dtype)
151156
result.set_dynamic_dim_values(dyn_dim_values)
152157
return result
153158

154159
@emitter
155160
def tensor_reshape(
156161
self, source: BuildableTensorType, *result_dims: BuildableTensorDimDecl
157-
) -> "IrValueTensor":
162+
) -> "IrTensor":
158163
constant_cache: Dict[int, Value] = {}
159164
source = cast_tensor_value(source)
160165
result_dim_decls, result_dynamic_dims = cast_tensor_dim_decl(result_dims)
@@ -167,14 +172,14 @@ def tensor_reshape(
167172
source.get_only_dynamic_dim_values(constant_cache=constant_cache),
168173
result_dynamic_dims,
169174
).result
170-
result = IrValueTensor(result_value, dtype=source.dtype)
175+
result = IrImmediateTensor(result_value, dtype=source.dtype)
171176
result.set_dynamic_dim_values(result_dynamic_dims)
172177
return result
173178

174179
@emitter
175180
def tensor_slice(
176181
self, source: BuildableTensorType, *indices: BuildableSliceType
177-
) -> "IrValueTensor":
182+
) -> "IrTensor":
178183
"""Extracts a slice of a tensor.
179184
180185
The given indices must match the rank of the source and each index is
@@ -246,7 +251,7 @@ def tensor_slice(
246251
length_values,
247252
result_dynamic_dims,
248253
).result
249-
result = IrValueTensor(result_value, dtype=source.dtype)
254+
result = IrImmediateTensor(result_value, dtype=source.dtype)
250255
result.set_dynamic_dim_values(result_dynamic_dims)
251256
return result
252257

@@ -256,7 +261,7 @@ def tensor_update(
256261
target: BuildableTensorType,
257262
update: BuildableTensorType,
258263
*start_indices: BuildableIndexType,
259-
) -> "IrValueTensor":
264+
) -> "IrTensor":
260265
"""Applies an update to a target at start_indices and returns the mutated target."""
261266
constant_cache: Dict[int, Value] = {}
262267
target = cast_tensor_value(target)
@@ -278,7 +283,7 @@ def tensor_update(
278283
update.ir_value,
279284
update_dynamic_dims,
280285
).result
281-
result = IrValueTensor(result_value, target.dtype)
286+
result = IrImmediateTensor(result_value, target.dtype)
282287
result.set_dynamic_dim_values(target_dynamic_dims)
283288
return result
284289

@@ -288,7 +293,7 @@ def tensor_splat(
288293
*dims: BuildableTensorDimDecl,
289294
value: BuildableScalarValue,
290295
dtype: torch.dtype,
291-
) -> "IrValueTensor":
296+
) -> "IrTensor":
292297
# TODO: Type infer the dtype if missing.
293298
dim_decls, dyn_dim_values = cast_tensor_dim_decl(dims)
294299
try:
@@ -302,7 +307,7 @@ def tensor_splat(
302307
)
303308
tensor_type = RankedTensorType.get(dim_decls, element_type)
304309
raw_tensor = flow_d.TensorSplatOp(tensor_type, value, dyn_dim_values).result
305-
result = IrValueTensor(raw_tensor, dtype=dtype)
310+
result = IrImmediateTensor(raw_tensor, dtype=dtype)
306311
result.set_dynamic_dim_values(dyn_dim_values)
307312
return result
308313

@@ -314,6 +319,6 @@ def tensor_trace(self, key: str, *ts: BuildableTensorType):
314319

315320
# Circular imports to resolve typing.
316321
from .primitives import (
317-
IrValueScalar,
318-
IrValueTensor,
322+
IrScalar,
323+
IrTensor,
319324
)

0 commit comments

Comments
 (0)