Skip to content

Commit 5d90a81

Browse files
hanhanWkeshavvinayak01
authored andcommitted
[CPU] Switch CPUDefault pipeline to use IREE::CPU::LoweringConfigAttr. (iree-org#21515)
Signed-off-by: hanhanW <[email protected]> Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 823c9bc commit 5d90a81

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,7 +2130,9 @@ setDefaultGenericOpRootConfig(mlir::FunctionOpInterface entryPointFn,
21302130
unsigned numLoops = genericOp.getNumLoops();
21312131
if (numLoops == 0) {
21322132
return setOpConfigAndEntryPointFnTranslation(
2133-
entryPointFn, genericOp, TileSizesListType{{}},
2133+
entryPointFn, genericOp,
2134+
IREE::CPU::LoweringConfigAttr::get(genericOp.getContext(),
2135+
SmallVector<NamedAttribute>()),
21342136
DispatchLoweringPassPipeline::CPUDefault);
21352137
}
21362138

@@ -2674,17 +2676,23 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
26742676
assert(!getLoweringConfig(op) && "expected lowering_config is not set");
26752677
SmallVector<int64_t> distTileSizes =
26762678
getDefaultDistributedLevelTileSizes(op, DistributionHeuristicConfig{});
2677-
TileSizesListType tileSizes = {distTileSizes};
2678-
SmallVector<int64_t> vecTileSizes = distTileSizes;
26792679

26802680
// Add an extra level of tiling.
26812681
// TODO: Limit vector tile sizes for other TilingInterface ops.
2682+
SmallVector<int64_t> vecTileSizes = distTileSizes;
26822683
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(*op)) {
26832684
limitVectorTileSizes(linalgOp, vecTileSizes);
26842685
}
2685-
tileSizes.push_back(vecTileSizes);
2686+
2687+
LoweringConfigGenerator generator(op);
2688+
generator.setDistributionTileSizes(distTileSizes);
2689+
generator.setVectorTileSizes(vecTileSizes);
2690+
IREE::CPU::LoweringConfigAttr loweringConfig =
2691+
generator.generateCPULoweringConfig();
2692+
LDBG("Set lowering_config for tensor.pad op: " << loweringConfig);
26862693
return setOpConfigAndEntryPointFnTranslation(
2687-
entryPointFn, op, tileSizes, DispatchLoweringPassPipeline::CPUDefault);
2694+
entryPointFn, op, loweringConfig,
2695+
DispatchLoweringPassPipeline::CPUDefault);
26882696
}
26892697

26902698
/// Redirects to methods that set the configuration based on operation type.

compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,9 +934,12 @@ func.func @scalar() attributes {hal.executable.target = #executable_target_embed
934934
iree_tensor_ext.dispatch.tensor.store %4, %1, offsets = [], sizes = [], strides = [] : tensor<f32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<f32>>
935935
return
936936
}
937+
// CHECK-DAG: #[[CONFIG:.+]] = #iree_cpu.lowering_config<distribution = []>
937938
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDefault>
938939
// CHECK: func.func @scalar()
939940
// CHECK-SAME: translation_info = #[[TRANSLATION]]
941+
// CHECK: linalg.generic
942+
// CHECK-SAME: lowering_config = #config
940943

941944
// -----
942945

@@ -2014,7 +2017,7 @@ func.func @test_tiling_cpu_default(%arg0: tensor<256x256xi8>, %arg1: tensor<256x
20142017
%0 = linalg.quantized_matmul ins(%arg0, %arg1, %arg2, %arg3 : tensor<256x256xi8>, tensor<256x256xi8>, i32, i32) outs(%arg4 : tensor<256x256xi32>) -> tensor<256x256xi32>
20152018
return %0 : tensor<256x256xi32>
20162019
}
2017-
// CHECK-DAG: #[[CONFIG0:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [4, 64, 0]]>
2020+
// CHECK-DAG: #[[CONFIG0:.+]] = #iree_cpu.lowering_config<distribution = [64, 64, 0], vector_common_parallel = [4, 64, 0], vector_reduction = [0, 0, 0]>
20182021
// CHECK-DAG: #[[TRANSLATION_INFO]] = #iree_codegen.translation_info<pipeline = CPUDefault>
20192022
// CHECK: func @test_tiling_cpu_default(
20202023
// CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]

0 commit comments

Comments
 (0)