Skip to content

Commit 324be39

Browse files
authored
[Codegen] Fix multiple function support in materialize user configs (iree-org#21227)
When verifying symbol imports, it was returning from the pass instead of continuing to the next function, causing us to only apply user configs to the first function.
1 parent fddefb8 commit 324be39

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ struct MaterializeUserConfigsPass final
230230
IREE::Codegen::TranslationInfoAttr translationInfo =
231231
getTranslationInfo(funcOp);
232232
if (translationInfo) {
233-
return;
233+
continue;
234234
}
235235

236236
/// First, apply all user configs.
@@ -256,13 +256,13 @@ struct MaterializeUserConfigsPass final
256256
translationInfo.getDispatchLoweringPassPipeline() !=
257257
IREE::Codegen::DispatchLoweringPassPipeline::
258258
TransformDialectCodegen) {
259-
return;
259+
continue;
260260
}
261261

262262
std::optional<SymbolRefAttr> strategyName =
263263
translationInfo.getCodegenSpec();
264264
if (!strategyName || *strategyName == SymbolRefAttr()) {
265-
return;
265+
continue;
266266
}
267267

268268
/// If we have a symbol, verify the existence of the symbol within the

compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,20 @@
1010
]>
1111
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
1212
module {
13-
func.func @preset_config() attributes {hal.executable.target = #executable_target_system_elf_x86_64_} {
13+
func.func @preset_config_0() attributes {hal.executable.target = #executable_target_system_elf_x86_64_} {
14+
%cst = arith.constant 0.000000e+00 : f32
15+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<128x256xf32>>
16+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<256x512xf32>>
17+
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<128x512xf32>>
18+
%3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<128x256xf32>> -> tensor<128x256xf32>
19+
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 512], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<256x512xf32>> -> tensor<256x512xf32>
20+
%5 = tensor.empty() : tensor<128x512xf32>
21+
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x512xf32>) -> tensor<128x512xf32>
22+
%7 = linalg.matmul {compilation_info = #compilation} ins(%3, %4 : tensor<128x256xf32>, tensor<256x512xf32>) outs(%6 : tensor<128x512xf32>) -> tensor<128x512xf32>
23+
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [128, 512], strides = [1, 1] : tensor<128x512xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<128x512xf32>>
24+
return
25+
}
26+
func.func @preset_config_1() attributes {hal.executable.target = #executable_target_system_elf_x86_64_} {
1427
%cst = arith.constant 0.000000e+00 : f32
1528
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<128x256xf32>>
1629
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<256x512xf32>>
@@ -27,7 +40,11 @@ module {
2740

2841
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [32, 32, 0], [0, 0, 32], [0, 0, 0]]>
2942
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
30-
// CHECK: func.func @preset_config()
43+
// CHECK: func.func @preset_config_0()
44+
// CHECK-SAME: translation_info = #[[TRANSLATION]]
45+
// CHECK: linalg.matmul
46+
// CHECK-SAME: lowering_config = #[[CONFIG]]
47+
// CHECK: func.func @preset_config_1()
3148
// CHECK-SAME: translation_info = #[[TRANSLATION]]
3249
// CHECK: linalg.matmul
3350
// CHECK-SAME: lowering_config = #[[CONFIG]]

0 commit comments

Comments
 (0)