
Commit eeda7ca

[Codegen][Encoding] Fix generic op materialization with 0D tensors (#21545)
The generic op materialization for data tiling fails when any input operand resolves to an identity layout but the output operand does not, and 0D tensors always resolve to an identity layout in data tiling. This PR adds a special case for 0D tensors to fix the bug.

Fixes #21540

Signed-off-by: Max Dawkins <[email protected]>
1 parent 3160987 commit eeda7ca

2 files changed: 50 additions & 1 deletion
compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp

Lines changed: 10 additions & 1 deletion
@@ -269,6 +269,16 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
 
   SmallVector<AffineMap> packedIndexingMaps;
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+    AffineMap inputMap = genericOp.getMatchingIndexingMap(inputOperand);
+    // Special case for 0D inputs. They will resolve to identity layout, so
+    // skip the logic to compute the packed indexing map.
+    if (inputMap.getNumResults() == 0) {
+      auto packedInputMap = AffineMap::get(
+          /*dimCount=*/iteratorTypes.size(), /*symbolCount=*/0, {},
+          rewriter.getContext());
+      packedIndexingMaps.push_back(packedInputMap);
+      continue;
+    }
     // Step 2: Retrieve the encoding for every input operand and perform the
     // outer dimension permutation, inner dimension expansion and permutation,
     // swizzle expansion and swizzle permutation.
@@ -310,7 +320,6 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
     }
     ArrayRef<int64_t> innerDimsPos = materializeEncodingInfo.innerDimsPos;
     ArrayRef<int64_t> outerDimsPerm = materializeEncodingInfo.outerDimsPerm;
-    AffineMap inputMap = genericOp.getMatchingIndexingMap(inputOperand);
     // Permute result dims to the input packed domain, and map dims to the
     // output packed domain.
    SmallVector<int64_t> packedResultDims = llvm::map_to_vector(
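
For reference, a minimal standalone sketch of what the special case above constructs. It is not part of the commit; the main() harness is hypothetical and assumes an MLIR development setup. With a 2-D iteration space, passing an empty result list to AffineMap::get yields the zero-result map (d0, d1) -> (), which is exactly the indexing map a 0D tensor operand already uses, so no packed layout transformation is needed.

// Minimal sketch (hypothetical harness, not part of this commit): the packed
// indexing map built for a 0D operand is the zero-result map.
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  // Mirror the special case above for a 2-D iteration space: no symbols and
  // an empty result list produce the identity layout for a 0D tensor.
  mlir::AffineMap packedInputMap = mlir::AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/0, /*results=*/{}, &ctx);
  packedInputMap.print(llvm::outs()); // prints "(d0, d1) -> ()"
  llvm::outs() << "\n";
  return 0;
}

Because the map has no results, the 0D operand reads the same scalar on every packed loop iteration, so the pack/unpack bookkeeping for encoded operands can be skipped with the continue in the patch.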

compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_x86_64.mlir

Lines changed: 40 additions & 0 deletions
@@ -2220,3 +2220,43 @@ func.func @set_encoding_transpose_multi_result() attributes {
 // CHECK: %[[PACK:.+]] = linalg.pack %[[TRANSPOSE]]
 // CHECK: iree_tensor_ext.dispatch.tensor.store %[[TRANSPOSE]], %[[RESULT_BINDING]]
 // CHECK: iree_tensor_ext.dispatch.tensor.store %[[PACK]], %[[RESULT_BINDING1]]
+
+// -----
+
+#executable_target_xyz = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz", iree.encoding.resolver = #iree_cpu.cpu_encoding_resolver<>}>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map3 = affine_map<(d0, d1) -> ()>
+#map4 = affine_map<(d0, d1) -> (d0, d1)>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+#encoding = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [2, 4, 3]>
+#encoding1 = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, [#map2, #map3]], iteration_sizes = [2, 4, 3]>
+func.func @generic_with_0d_tensor() attributes {hal.executable.target = #executable_target_xyz} {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x4xf32, #encoding>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>>
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x4xf32, #encoding>>
+  %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2, 4], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x4xf32, #encoding>> -> tensor<2x4xf32, #encoding>
+  %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
+  %5 = iree_encoding.set_encoding %4 : tensor<f32> -> tensor<f32, #encoding1>
+  %6 = tensor.empty() : tensor<2x4xf32, #encoding>
+  %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<2x4xf32, #encoding>) -> tensor<2x4xf32, #encoding>
+  %8 = linalg.generic {indexing_maps = [#map4, #map3, #map4], iterator_types = ["parallel", "parallel"]} ins(%3, %5 : tensor<2x4xf32, #encoding>, tensor<f32, #encoding1>) outs(%7 : tensor<2x4xf32, #encoding>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %9 = arith.addf %in, %in_0 : f32
+    linalg.yield %9 : f32
+  } -> tensor<2x4xf32, #encoding>
+  iree_tensor_ext.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [2, 4], strides = [1, 1] : tensor<2x4xf32, #encoding> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x4xf32, #encoding>>
+  return
+}
+
+// CHECK-LABEL: func.func @generic_with_0d_tensor
+// CHECK-DAG:   %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0)
+// CHECK-DAG:   %[[INPUT_0D_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1)
+// CHECK-DAG:   %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(2)
+// CHECK-DAG:   %[[INPUT:.+]] = iree_tensor_ext.dispatch.tensor.load %[[INPUT_BINDING]]
+// CHECK-DAG:   %[[INPUT_0D:.+]] = iree_tensor_ext.dispatch.tensor.load %[[INPUT_0D_BINDING]]
+// CHECK:       %[[GENERIC:.+]] = linalg.generic {{.*}} ins(%[[INPUT]], %[[INPUT_0D]] : tensor<1x1x2x4xf32>, tensor<f32>)
+// CHECK:       iree_tensor_ext.dispatch.tensor.store %[[GENERIC]], %[[RESULT_BINDING]]
