Commit ac3e153

[DT][CPU] Exclude pack ops with reshape producers from lowering config setting (#21675)
Without this, the reshape can break dimension tracking, creating extra false dimensions that lead to incorrect lowering configurations. See #21670.

Signed-off-by: Yu-Zhewen <[email protected]>
1 parent d9e71c6 commit ac3e153
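
In IR terms, the producer pattern this change excludes is a `linalg.pack` whose source comes directly from a `tensor.collapse_shape`, as in this excerpt distilled from the test added below:

  %collapsed = tensor.collapse_shape %3 [[0], [1, 2, 3]] : tensor<?x4x2x32xf16> into tensor<?x256xf16>
  %pack = linalg.pack %collapsed padding_value(%cst : f16) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %4 : tensor<?x256xf16> -> tensor<?x256x1x1xf16>

Such pack ops are skipped when assigning lowering configs so that the reshape + pack pair can be lowered together into a `map_scatter` later in the pipeline.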

File tree

2 files changed (+55, -8 lines)

compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp

Lines changed: 25 additions & 8 deletions
@@ -3450,6 +3450,29 @@ lowerUsingDefaultPipeline(mlir::FunctionOpInterface entryPointFn) {
   return setTranslationInfo(entryPointFn, translationInfo);
 }
 
+/// Returns true if the given operation should have a lowering config set.
+///
+/// This predicate excludes:
+/// - Ops inside a `CustomOp` that already have a lowering config.
+/// - `linalg.pack` ops whose producer is a `tensor.collapse_shape`,
+///   as they will be lowered together into a `map_scatter` later in the
+///   pipeline.
+static bool shouldSetLoweringConfig(Operation *op) {
+  if (isa_and_nonnull<IREE::LinalgExt::CustomOp>(op->getParentOp()) &&
+      getLoweringConfig(op) != nullptr) {
+    return false;
+  }
+
+  if (auto packOp = dyn_cast<linalg::PackOp>(op)) {
+    if (isa_and_nonnull<tensor::CollapseShapeOp>(
+            packOp.getSource().getDefiningOp())) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 /// Sets the translation information to use for a dispatch region.
 static LogicalResult
 setTranslationInfoAndRootConfig(mlir::FunctionOpInterface entryPointFn,
@@ -3488,14 +3511,8 @@ setTranslationInfoAndRootConfig(mlir::FunctionOpInterface entryPointFn,
     return failure();
   }
 
-  // Avoid this for ops within a custom_op since those ops have already their
-  // configuration set.
-  auto prunedComputeOps =
-      llvm::to_vector(llvm::make_filter_range(computeOps, [](Operation *op) {
-        return !isa_and_nonnull<IREE::LinalgExt::CustomOp>(
-                   op->getParentOp()) ||
-               getLoweringConfig(op) == nullptr;
-      }));
+  auto prunedComputeOps = llvm::to_vector(
+      llvm::make_filter_range(computeOps, shouldSetLoweringConfig));
   if (failed(setLoweringConfigForComputeOps(entryPointFn, prunedComputeOps,
                                             rootOperation))) {
     return failure();
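
For readers unfamiliar with the LLVM ADT idiom in the replacement code: `llvm::make_filter_range` wraps a range in a lazy view that skips elements failing a predicate, and `llvm::to_vector` materializes the survivors into a SmallVector. A minimal standalone sketch of the idiom (toy values, not IREE code):

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SmallVector.h"
  #include <cstdio>

  // Named predicate, mirroring the refactor's move away from an inline lambda.
  static bool isEven(int n) { return n % 2 == 0; }

  int main() {
    llvm::SmallVector<int> values = {1, 2, 3, 4, 5, 6};
    // make_filter_range is lazy; to_vector forces a single pass and copies
    // the elements passing the predicate into a new SmallVector.
    auto evens = llvm::to_vector(llvm::make_filter_range(values, isEven));
    for (int v : evens)
      std::printf("%d\n", v); // prints 2, 4, 6
    return 0;
  }

Hoisting the lambda into the named `shouldSetLoweringConfig` also gives the new pack-with-reshape-producer exclusion a natural place to be documented alongside the existing custom_op rule.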

compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir

Lines changed: 30 additions & 0 deletions
@@ -2117,3 +2117,33 @@ func.func @decode_reduction_f32(%arg0: tensor<32x262144xf16>, %arg1: tensor<32xf
 // CHECK-SAME: lowering_config = #[[CONFIG0]]
 // CHECK: linalg.generic
 // CHECK-SAME: lowering_config = #[[CONFIG1]]
+
+// -----
+
+#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "+avx512f", native_vector_size = 64 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
+func.func @attention_reshape_pack(%arg0: index, %arg1: tensor<4x2x?x32xf16>, %arg2: tensor<?x4x32xf16>, %arg3: tensor<?x4x32xf16>, %arg4: tensor<4x2x?x?xf16>) -> tensor<?x256x1x1xf16> attributes {hal.executable.target = #executable_target_embedded_elf_x86_64} {
+  %cst = arith.constant 0.000000e+00 : f16
+  %cst_0 = arith.constant 1.767580e-01 : f16
+  %0 = tensor.empty(%arg0) : tensor<?x4x2x32xf16>
+  %1 = tensor.empty(%arg0) : tensor<4x2x?x32xf16>
+  %2 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]} ins(%arg1, %arg2, %arg3, %cst_0, %arg4 : tensor<4x2x?x32xf16>, tensor<?x4x32xf16>, tensor<?x4x32xf16>, f16, tensor<4x2x?x?xf16>) outs(%1 : tensor<4x2x?x32xf16>) {
+  ^bb0(%arg5: f32):
+    iree_linalg_ext.yield %arg5 : f32
+  } -> tensor<4x2x?x32xf16>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<4x2x?x32xf16>) outs(%0 : tensor<?x4x2x32xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<?x4x2x32xf16>
+  %collapsed = tensor.collapse_shape %3 [[0], [1, 2, 3]] : tensor<?x4x2x32xf16> into tensor<?x256xf16>
+  %4 = tensor.empty(%arg0) : tensor<?x256x1x1xf16>
+  %pack = linalg.pack %collapsed padding_value(%cst : f16) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %4 : tensor<?x256xf16> -> tensor<?x256x1x1xf16>
+  return %pack : tensor<?x256x1x1xf16>
+}
+// CHECK-DAG: #[[CONFIG0:.+]] = #iree_cpu.lowering_config<distribution = [1, 1, 64, 16, 0, 0], vector_common_parallel = [1, 1, 4, 16, 0, 0], vector_reduction = [0, 0, 0, 0, 0, 32]>
+// CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 1, 4, 16]>
+// CHECK-NOT: #iree_cpu.lowering_config
+// CHECK: func.func @attention_reshape_pack
+// CHECK: iree_linalg_ext.attention
+// CHECK-SAME: lowering_config = #[[CONFIG0]]
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG1]]
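
For context on how the new case is exercised: the `// -----` line splits the file into independent lit test cases, and the suite's RUN line (unchanged by this commit, so not visible in the diff) pipes each case through iree-opt's LLVMCPU strategy-selection pass into FileCheck. Assuming the pipeline name used by this suite, a representative RUN line would look like:

  // RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmcpu-select-lowering-strategy)' --split-input-file %s | FileCheck %s

The `CHECK-NOT: #iree_cpu.lowering_config` line carries the fix's verification: beyond CONFIG0 and CONFIG1, no further lowering config may be created, so the `linalg.pack` fed by the `tensor.collapse_shape` must be left without one.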
