Skip to content

Commit 502d6b5

Browse files
authored
fix: correctly set symrefattr for kernel/jit call (#1448)
1 parent 9b689ed commit 502d6b5

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

src/enzyme_ad/jax/Passes/LowerKernel.cpp

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,23 @@ using namespace enzymexla;
4747

4848
using namespace stablehlo;
4949

50+
SymbolRefAttr SymRefAttrReplacingFunctionName(SymbolRefAttr origSymRef,
51+
StringRef newName) {
52+
auto newNameRef = FlatSymbolRefAttr::get(origSymRef.getContext(), newName);
53+
auto nestedRefsAttr = origSymRef.getNestedReferences();
54+
if (nestedRefsAttr.size() == 0) {
55+
return newNameRef;
56+
}
57+
58+
auto rootRef = origSymRef.getRootReference();
59+
SmallVector<FlatSymbolRefAttr> nestedRefs;
60+
for (int i = 0; i < nestedRefsAttr.size() - 1; i++) {
61+
nestedRefs.push_back(nestedRefsAttr[i]);
62+
}
63+
nestedRefs.push_back(newNameRef);
64+
return SymbolRefAttr::get(origSymRef.getContext(), rootRef, nestedRefs);
65+
}
66+
5067
bool CompileGPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
5168
FunctionOpInterface op, size_t gridx, size_t gridy,
5269
size_t gridz, size_t blockx, size_t blocky, size_t blockz,
@@ -229,11 +246,11 @@ bool CompileGPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
229246
OpBuilder rewriter(kcall);
230247
auto replacement = rewriter.create<enzymexla::JITCallOp>(
231248
kcall.getLoc(), kcall.getResultTypes(),
232-
SymbolRefAttr::get(kcall.getContext(), callName, {}), kcall.getInputs(),
233-
kcall.getBackendConfigAttr(), kcall.getOperandLayoutsAttr(),
234-
kcall.getResultLayoutsAttr(), kcall.getArgAttrsAttr(),
235-
kcall.getResAttrsAttr(), kcall.getOutputOperandAliasesAttr(),
236-
kcall.getXlaSideEffectFreeAttr());
249+
SymRefAttrReplacingFunctionName(kcall.getFn(), callName),
250+
kcall.getInputs(), kcall.getBackendConfigAttr(),
251+
kcall.getOperandLayoutsAttr(), kcall.getResultLayoutsAttr(),
252+
kcall.getArgAttrsAttr(), kcall.getResAttrsAttr(),
253+
kcall.getOutputOperandAliasesAttr(), kcall.getXlaSideEffectFreeAttr());
237254
kcall.replaceAllUsesWith(replacement);
238255
kcall.erase();
239256
return true;
@@ -398,11 +415,11 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
398415
OpBuilder rewriter(kcall);
399416
auto replacement = rewriter.create<enzymexla::JITCallOp>(
400417
kcall.getLoc(), kcall.getResultTypes(),
401-
SymbolRefAttr::get(kcall.getContext(), callName, {}), kcall.getInputs(),
402-
kcall.getBackendConfigAttr(), kcall.getOperandLayoutsAttr(),
403-
kcall.getResultLayoutsAttr(), kcall.getArgAttrsAttr(),
404-
kcall.getResAttrsAttr(), kcall.getOutputOperandAliasesAttr(),
405-
kcall.getXlaSideEffectFreeAttr());
418+
SymRefAttrReplacingFunctionName(kcall.getFn(), callName),
419+
kcall.getInputs(), kcall.getBackendConfigAttr(),
420+
kcall.getOperandLayoutsAttr(), kcall.getResultLayoutsAttr(),
421+
kcall.getArgAttrsAttr(), kcall.getResAttrsAttr(),
422+
kcall.getOutputOperandAliasesAttr(), kcall.getXlaSideEffectFreeAttr());
406423
kcall.replaceAllUsesWith(replacement);
407424
kcall.erase();
408425
return true;

0 commit comments

Comments
 (0)