Skip to content

Commit b2b3ed1

Browse files
authored
[DT] Fix a bug in encoding propagation when there are scalar inputs. (#21596)
A linalg op can take scalars as inputs, and the encoding propagation can ignore it. It is similar to 0-D tensor, but a scalar is used. The failure is from the result of QuantizedMatmulToMatmul pass. We can switch to 0-D tensor as a fix. However, it is a legal linalg op, so supporting the propagation on scalars is a better fix. --------- Signed-off-by: hanhanW <[email protected]>
1 parent 3256168 commit b2b3ed1

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

compiler/src/iree/compiler/DispatchCreation/test/hoist_encoding_ops.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,41 @@ util.func public @propagate_unset_encoding_through_generic(%arg0: tensor<?x4096x
327327

328328
// -----
329329

330+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
331+
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
332+
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
333+
#map3 = affine_map<(d0, d1) -> (d0, d1)>
334+
#map4 = affine_map<(d0, d1) -> ()>
335+
#encoding = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>
336+
util.func public @propagate_unset_encoding_through_generic_with_scalar(%arg0: tensor<4096x?xf32, #encoding>, %arg1: f32, %arg2: index) -> tensor<4096x?xf32> {
337+
%0 = flow.dispatch.region -> (tensor<4096x?xf32>{%arg2}) {
338+
%1 = iree_encoding.unset_encoding %arg0 : tensor<4096x?xf32, #encoding> -> tensor<4096x?xf32>{%arg2}
339+
%2 = tensor.empty(%arg2) : tensor<4096x?xf32>
340+
%3 = linalg.generic {indexing_maps = [#map3, #map4, #map3], iterator_types = ["parallel", "parallel"]} ins(%1, %arg1 : tensor<4096x?xf32>, f32) outs(%2 : tensor<4096x?xf32>) {
341+
^bb0(%in: f32, %in_0: f32, %out: f32):
342+
%4 = arith.mulf %in, %in_0 : f32
343+
linalg.yield %4 : f32
344+
} -> tensor<4096x?xf32>
345+
flow.return %3 : tensor<4096x?xf32>
346+
}
347+
util.return %0 : tensor<4096x?xf32>
348+
}
349+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
350+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
351+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
352+
// CHECK-DAG: #[[$ENCODING:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]>
353+
// CHECK-LABEL: @propagate_unset_encoding_through_generic_with_scalar(
354+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
355+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
356+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
357+
// CHECK: %{{.+}} = flow.dispatch.region -> (tensor<4096x?xf32>{%[[ARG2]]}
358+
// CHECK: %[[GENERIC:.+]] = linalg.generic
359+
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
360+
// CHECK: %[[UNSET_ENCODING:.+]] = iree_encoding.unset_encoding %[[GENERIC]] : tensor<4096x?xf32, #[[$ENCODING]]> -> tensor<4096x?xf32>{%[[ARG2]]}
361+
// CHECK: return %[[UNSET_ENCODING]]
362+
363+
// -----
364+
330365
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
331366
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
332367
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,12 @@ struct GenericOpPropagationInterface
219219
}
220220

221221
auto operandType =
222-
cast<RankedTensorType>(operand->get().getType());
222+
dyn_cast<RankedTensorType>(operand->get().getType());
223+
if (!operandType) {
224+
// Scalar types do not need encodings.
225+
encodedOperands.push_back(operand->get());
226+
continue;
227+
}
223228
auto resType = RankedTensorType::get(
224229
operandType.getShape(), operandType.getElementType(),
225230
encoding);

0 commit comments

Comments (0)