
Commit 6b21267

Bump IREE to 20240119.775. (#360)
1 parent 090f359 · commit 6b21267

7 files changed (+11 -12 lines)

python/shark_turbine/kernel/compiler/builder.py

Lines changed: 2 additions & 2 deletions
@@ -128,8 +128,8 @@ def py_constant_int(self, py_value) -> Value:
         # If coming from a stock 'int' Python type with no idea how to convert it,
         # there isn't much smart we can do. We conservatively treat 'index' as
         # reasonable.
-        attr = IntegerAttr.get(IndexType.get(), py_value)
-        return arith_d.constant(attr)
+        result_type = IndexType.get()
+        return arith_d.constant(result_type, IntegerAttr.get(result_type, py_value))

     # Binary index arithmetic.
     def binary_add_index_index(self, lhs: Value, rhs: Value) -> Value:
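
Note on the change above: the IREE release bundled here ships newer MLIR Python bindings in which the generated arith_d.constant helper takes the result type as an explicit first argument rather than deriving it from the typed attribute alone; this commit updates every call site to that two-argument form. A minimal standalone sketch of the new call pattern, assuming the iree.compiler Python package from this release (the script below is illustrative and not part of the Turbine codebase):

# Illustrative sketch only: building an index constant with the updated
# arith_d.constant signature (result type first, then the typed attribute).
from iree.compiler.ir import Context, InsertionPoint, Location, Module, IndexType, IntegerAttr
from iree.compiler.dialects import arith as arith_d

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        result_type = IndexType.get()
        # Older call sites passed only the attribute; this commit passes the
        # result type explicitly as well.
        one = arith_d.constant(result_type, IntegerAttr.get(result_type, 1))
    print(module)  # the printed module contains: arith.constant 1 : index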

python/shark_turbine/kernel/compiler/dispatch_codegen.py

Lines changed: 4 additions & 2 deletions
@@ -171,8 +171,9 @@ def abi_type(binding: BindingDesc):
         workgroup_values = list(workgroup_builder.workload)
         while len(workgroup_values) < 3:
             with InsertionPoint(workgroup_builder.entry_block):
+                result_type = IndexType.get()
                 workgroup_values.append(
-                    arith_d.constant(IntegerAttr.get(IndexType.get(), 1))
+                    arith_d.constant(result_type, IntegerAttr.get(result_type, 1))
                 )
         workgroup_builder.terminate(workgroup_values)

@@ -226,7 +227,8 @@ def resolve(self, binding: BindingDesc) -> Value:

         if binding.binding_type == BindingType.KERNEL_BUFFER:
             # Issue a subspan to get into the memref domain.
-            zero_value = arith_d.constant(IntegerAttr.get(IndexType.get(), 0))
+            result_type = IndexType.get()
+            zero_value = arith_d.constant(result_type, IntegerAttr.get(result_type, 0))
             linear_arg_value = self._abi_value_by_reference[binding.reference]
             # TODO: Need to also look up dynamic symbol values.
             return stream_d.binding_subspan(

python/shark_turbine/kernel/compiler/vector_codegen.py

Lines changed: 3 additions & 3 deletions
@@ -266,7 +266,7 @@ def _(emitter: ThreadEmitter, node: fx.Node):
     vector_type = VectorType.get(vector_shape, element_type)
     pad_attr = ScalarBuilder.zero_attr(element_type)
     indices = cast_indices(emitter, [s.start for s in sa.slices])
-    pad_value = arith_d.constant(pad_attr)
+    pad_value = arith_d.constant(element_type, pad_attr)
     result = vector_d.transfer_read(
         vector_type, kb_src, indices, AffineMap.get_identity(len(indices)), pad_value
     )
@@ -329,7 +329,7 @@ def combiner(element_type: IrType, attrs: NodeAttrs) -> vector_d.CombiningKind:
             # Non-NaN propagating.
             # TODO: Carry a "fastmath" flag on the emitter and choose between this
             # and MAXIMUMF?
-            return vector_d.CombiningKind.MAXF
+            return vector_d.CombiningKind.MAXNUMF
         elif ScalarBuilder.is_integer_type(element_type):
             return (
                 vector_d.CombiningKind.MAXUI
@@ -365,7 +365,7 @@ def emit_reduction(
     vector_type = VectorType(input.type)
     element_type = vector_type.element_type
     rank = vector_type.rank
-    zero = arith_d.constant(ScalarBuilder.zero_attr(element_type))
+    zero = arith_d.constant(element_type, ScalarBuilder.zero_attr(element_type))
     combiner = combiner_callback(element_type, attrs)

     if len(args) == 1:
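
On the MAXF change above: recent MLIR revisions split the floating-point max combining kind into a NaN-propagating variant (MAXIMUMF) and a non-NaN-propagating one (MAXNUMF) and dropped the old MAXF name, so this commit selects MAXNUMF, the non-propagating variant, as the in-code comment notes. A hedged sketch of how the TODO's fastmath-style switch could look; prefer_nan_propagation is a hypothetical knob, not something the emitter currently carries:

# Illustrative only: selecting between the two float "max" combining kinds.
# The emitter in this commit always uses MAXNUMF; a flag like
# prefer_nan_propagation is hypothetical (see the TODO in the hunk above).
from iree.compiler.dialects import vector as vector_d


def float_max_combining_kind(prefer_nan_propagation: bool = False) -> vector_d.CombiningKind:
    if prefer_nan_propagation:
        # IEEE-maximum semantics: a NaN operand propagates to the result.
        return vector_d.CombiningKind.MAXIMUMF
    # maxnum semantics: a NaN operand is ignored when the other operand is a number.
    return vector_d.CombiningKind.MAXNUMF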

python/turbine_models/custom_models/llama-benchmark/benchmark_module.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ def create_benchmark_vmfb(args):
         "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
         "--iree-stream-resource-index-bits=64",
         "--iree-vm-target-index-bits=64",
-        "--iree-codegen-check-ir-before-llvm-conversion=false",
         "--iree-opt-const-expr-hoisting=False",
     ]
     device = args.device

python/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
         "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
         "--iree-stream-resource-index-bits=64",
         "--iree-vm-target-index-bits=64",
-        "--iree-codegen-check-ir-before-llvm-conversion=false",
         "--iree-opt-const-expr-hoisting=False",
     ]
     if device == "cpu":

python/turbine_models/custom_models/stateless_llama.py

Lines changed: 0 additions & 1 deletion
@@ -321,7 +321,6 @@ def evict_kvcache_space(self):
         "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
         "--iree-stream-resource-index-bits=64",
         "--iree-vm-target-index-bits=64",
-        "--iree-codegen-check-ir-before-llvm-conversion=false",
         "--iree-opt-const-expr-hoisting=False",
     ]
     if device == "cpu" or device == "llvm-cpu":

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -7,5 +7,5 @@
 -r pytorch-cpu-requirements.txt
 -r torchvision-requirements.txt

-iree-compiler==20231218.742
-iree-runtime==20231218.742
+iree-compiler==20240119.775
+iree-runtime==20240119.775
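
As a small optional sanity check, an environment can confirm it actually picked up the pinned nightlies above; this is only a sketch, and the distribution names are exactly those pinned in requirements.txt:

# Verify that the installed IREE packages match the pins in requirements.txt.
from importlib.metadata import version

for pkg, expected in (("iree-compiler", "20240119.775"), ("iree-runtime", "20240119.775")):
    installed = version(pkg)
    assert installed == expected, f"{pkg}: expected {expected}, found {installed}"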
