Skip to content

Commit 8ba9f68

Browse files
authored
[ROCM] Fix redefinition of symbol error for including tensor ukernels (iree-org#21780)
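Fixes the case where multiple functions require the same tensor ukernel to be included. Without this PR, this leads to a symbol redefinition error. The pass now records the ukernel symbols it has already collected during the module walk and skips any ukernel whose symbol was seen before. Signed-off-by: Jorn Tuyls <[email protected]>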
1 parent 33e2146 commit 8ba9f68

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,15 +299,15 @@ class ApplyBuiltinPDLPatternsDriverPass final
299299
MLIRContext *ctx = moduleOp.getContext();
300300
auto rocmDialect = ctx->getOrLoadDialect<IREE::ROCM::ROCMDialect>();
301301
SmallVector<FunctionOpInterface> ukernelFunctions;
302+
llvm::SmallDenseSet<StringRef> ukernelSymbols;
302303
auto res = moduleOp.walk([&](Operation *op) {
303304
auto builtinName =
304305
dyn_cast_or_null<StringAttr>(op->getAttr(kBuiltinName));
305306
auto ukernelDesc = getUKernelDescriptor(op);
306307
if (!builtinName || !ukernelDesc) {
307308
return WalkResult::advance();
308309
}
309-
if (moduleOp->hasAttr(ukernelDesc.getUkernelName())) {
310-
// Avoid parsing and serializing the same ukernel again and again.
310+
if (ukernelSymbols.contains(ukernelDesc.getUkernelName())) {
311311
return WalkResult::advance();
312312
}
313313
std::optional<StringRef> maybeBuiltin =
@@ -335,6 +335,7 @@ class ApplyBuiltinPDLPatternsDriverPass final
335335
funcOp->remove();
336336
ukernelFunctions.push_back(funcOp);
337337
op->removeAttr(kBuiltinName);
338+
ukernelSymbols.insert(ukernelDesc.getUkernelName());
338339
return WalkResult::advance();
339340
});
340341
if (res.wasInterrupted()) {

compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ module attributes {
3131
} -> tensor<1x128x1024xf32>
3232
return %2 : tensor<1x128x1024xf32>
3333
}
34+
// Check that a second function requiring the same ukernel doesn't lead to a 'redefinition of symbol named ...' error.
35+
func.func @matmul_f8_medium_expanded_2(%arg0: tensor<1x128x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x128x1024xf32> {
36+
%cst = arith.constant 0.000000e+00 : f32
37+
%0 = tensor.empty() : tensor<1x128x1024xf32>
38+
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
39+
%2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xf8E4M3FNUZ>, tensor<1024x4096xf8E4M3FNUZ>) outs(%1 : tensor<1x128x1024xf32>) {
40+
^bb0(%in: f8E4M3FNUZ, %in_4: f8E4M3FNUZ, %out: f32):
41+
%12 = arith.extf %in : f8E4M3FNUZ to f32
42+
%13 = arith.extf %in_4 : f8E4M3FNUZ to f32
43+
%14 = arith.mulf %12, %13 : f32
44+
%15 = arith.addf %out, %14 : f32
45+
linalg.yield %15 : f32
46+
} -> tensor<1x128x1024xf32>
47+
return %2 : tensor<1x128x1024xf32>
48+
}
3449
}
3550
// CHECK-LABEL: util.func private @pingpong_medium_f8_expanded
3651
// CHECK: iree_codegen.inner_tiled

0 commit comments

Comments
 (0)