@@ -47,6 +47,23 @@ using namespace enzymexla;
4747
4848using namespace stablehlo ;
4949
50+ SymbolRefAttr SymRefAttrReplacingFunctionName (SymbolRefAttr origSymRef,
51+ StringRef newName) {
52+ auto newNameRef = FlatSymbolRefAttr::get (origSymRef.getContext (), newName);
53+ auto nestedRefsAttr = origSymRef.getNestedReferences ();
54+ if (nestedRefsAttr.size () == 0 ) {
55+ return newNameRef;
56+ }
57+
58+ auto rootRef = origSymRef.getRootReference ();
59+ SmallVector<FlatSymbolRefAttr> nestedRefs;
60+ for (int i = 0 ; i < nestedRefsAttr.size () - 1 ; i++) {
61+ nestedRefs.push_back (nestedRefsAttr[i]);
62+ }
63+ nestedRefs.push_back (newNameRef);
64+ return SymbolRefAttr::get (origSymRef.getContext (), rootRef, nestedRefs);
65+ }
66+
5067bool CompileGPUKernel (SymbolTableCollection &symbolTable, mlir::Location loc,
5168 FunctionOpInterface op, size_t gridx, size_t gridy,
5269 size_t gridz, size_t blockx, size_t blocky, size_t blockz,
@@ -229,11 +246,11 @@ bool CompileGPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
229246 OpBuilder rewriter (kcall);
230247 auto replacement = rewriter.create <enzymexla::JITCallOp>(
231248 kcall.getLoc (), kcall.getResultTypes (),
232- SymbolRefAttr::get (kcall.getContext (), callName, {}), kcall. getInputs ( ),
233- kcall.getBackendConfigAttr (), kcall.getOperandLayoutsAttr (),
234- kcall.getResultLayoutsAttr (), kcall.getArgAttrsAttr (),
235- kcall.getResAttrsAttr (), kcall.getOutputOperandAliasesAttr (),
236- kcall.getXlaSideEffectFreeAttr ());
249+ SymRefAttrReplacingFunctionName (kcall.getFn (), callName),
250+ kcall.getInputs (), kcall.getBackendConfigAttr (),
251+ kcall.getOperandLayoutsAttr (), kcall.getResultLayoutsAttr (),
252+ kcall.getArgAttrsAttr (), kcall.getResAttrsAttr (),
253+ kcall.getOutputOperandAliasesAttr (), kcall. getXlaSideEffectFreeAttr ());
237254 kcall.replaceAllUsesWith (replacement);
238255 kcall.erase ();
239256 return true ;
@@ -398,11 +415,11 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
398415 OpBuilder rewriter (kcall);
399416 auto replacement = rewriter.create <enzymexla::JITCallOp>(
400417 kcall.getLoc (), kcall.getResultTypes (),
401- SymbolRefAttr::get (kcall.getContext (), callName, {}), kcall. getInputs ( ),
402- kcall.getBackendConfigAttr (), kcall.getOperandLayoutsAttr (),
403- kcall.getResultLayoutsAttr (), kcall.getArgAttrsAttr (),
404- kcall.getResAttrsAttr (), kcall.getOutputOperandAliasesAttr (),
405- kcall.getXlaSideEffectFreeAttr ());
418+ SymRefAttrReplacingFunctionName (kcall.getFn (), callName),
419+ kcall.getInputs (), kcall.getBackendConfigAttr (),
420+ kcall.getOperandLayoutsAttr (), kcall.getResultLayoutsAttr (),
421+ kcall.getArgAttrsAttr (), kcall.getResAttrsAttr (),
422+ kcall.getOutputOperandAliasesAttr (), kcall. getXlaSideEffectFreeAttr ());
406423 kcall.replaceAllUsesWith (replacement);
407424 kcall.erase ();
408425 return true ;
0 commit comments