Skip to content

Commit 67dc896

Browse files
authored
feat: allow symbolref for kernel_call/jit_call (#1438)
* feat: allow symbolref for kernel_call/jit_call * chore: run fmt * test: use specific reactant branch
1 parent 936ed09 commit 67dc896

File tree

5 files changed

+25
-35
lines changed

5 files changed

+25
-35
lines changed

.github/workflows/test-gb-25.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ jobs:
5353
- 'main'
5454
# - '0123456789abcdef0123456789abcdef01234567'
5555
reactant_commit:
56-
- 'main'
56+
# - 'main'
57+
- ap/symbol_ref_kernel_call
5758

5859
steps:
5960
- name: Check GPUs

src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def KernelCallOp: EnzymeXLA_Op<"kernel_call", [DeclareOpInterfaceMethods<SymbolU
3838
let summary = "Kernel Call operation";
3939

4040
let arguments = (ins
41-
FlatSymbolRefAttr:$fn,
41+
SymbolRefAttr:$fn,
4242
TensorI64:$gridx,
4343
TensorI64:$gridy,
4444
TensorI64:$gridz,
@@ -108,7 +108,7 @@ def JITCallOp: EnzymeXLA_Op<"jit_call", [DeclareOpInterfaceMethods<SymbolUserOpI
108108
let summary = "JIT Call operation";
109109

110110
let arguments = (ins
111-
FlatSymbolRefAttr:$fn,
111+
SymbolRefAttr:$fn,
112112
Variadic<AnyType>:$inputs,
113113
DefaultValuedStrAttr<StrAttr, "">:$backend_config,
114114
OptionalAttr<AnyAttr>:$operand_layouts,

src/enzyme_ad/jax/Dialect/Ops.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,13 @@ KernelCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
108108
}
109109

110110
void KernelCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
111-
auto symbol = cast<SymbolRefAttr>(callee);
112-
setFnAttr(cast<FlatSymbolRefAttr>(symbol));
111+
setFnAttr(cast<SymbolRefAttr>(callee));
113112
}
114113

115114
CallInterfaceCallable KernelCallOp::getCallableForCallee() {
116-
return SymbolRefAttr::get(getContext(), getFn());
115+
auto attr = getFnAttr();
116+
return SymbolRefAttr::get(getContext(), attr.getRootReference(),
117+
attr.getNestedReferences());
117118
}
118119

119120
Operation::operand_range KernelCallOp::getArgOperands() { return getInputs(); }
@@ -157,8 +158,7 @@ void KernelCallOp::getEffects(
157158
ModuleOp moduleOp = (*this)->getParentOfType<ModuleOp>();
158159
assert(moduleOp && "KernelCallOp must be inside a ModuleOp");
159160

160-
auto callee =
161-
moduleOp.lookupSymbol<FunctionOpInterface>(getFnAttr().getAttr());
161+
auto callee = moduleOp.lookupSymbol<FunctionOpInterface>(getFnAttr());
162162
assert(callee && "KernelCallOp must have a valid function");
163163

164164
auto effectsAttr =
@@ -184,12 +184,13 @@ LogicalResult JITCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
184184
}
185185

186186
void JITCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
187-
auto symbol = cast<SymbolRefAttr>(callee);
188-
setFnAttr(cast<FlatSymbolRefAttr>(symbol));
187+
setFnAttr(cast<SymbolRefAttr>(callee));
189188
}
190189

191190
CallInterfaceCallable JITCallOp::getCallableForCallee() {
192-
return SymbolRefAttr::get(getContext(), getFn());
191+
auto attr = getFnAttr();
192+
return SymbolRefAttr::get(getContext(), attr.getRootReference(),
193+
attr.getNestedReferences());
193194
}
194195

195196
MutableOperandRange JITCallOp::getArgOperandsMutable() {
@@ -204,8 +205,7 @@ void JITCallOp::getEffects(
204205
ModuleOp moduleOp = (*this)->getParentOfType<ModuleOp>();
205206
assert(moduleOp && "JITCallOp must be inside a ModuleOp");
206207

207-
auto callee =
208-
moduleOp.lookupSymbol<FunctionOpInterface>(getFnAttr().getAttr());
208+
auto callee = moduleOp.lookupSymbol<FunctionOpInterface>(getFnAttr());
209209
assert(callee && "JITCallOp must have a valid function");
210210

211211
auto effectsAttr =
@@ -1757,8 +1757,7 @@ XLAWrapperOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
17571757
}
17581758

17591759
void XLAWrapperOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1760-
auto symbol = cast<SymbolRefAttr>(callee);
1761-
setFnAttr(cast<FlatSymbolRefAttr>(symbol));
1760+
setFnAttr(cast<SymbolRefAttr>(callee));
17621761
}
17631762

17641763
CallInterfaceCallable XLAWrapperOp::getCallableForCallee() { return getFn(); }

src/enzyme_ad/jax/Passes/LowerKernel.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,11 @@ bool CompileGPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
229229
OpBuilder rewriter(kcall);
230230
auto replacement = rewriter.create<enzymexla::JITCallOp>(
231231
kcall.getLoc(), kcall.getResultTypes(),
232-
mlir::FlatSymbolRefAttr::get(kcall.getContext(), callName),
233-
kcall.getInputs(), kcall.getBackendConfigAttr(),
234-
kcall.getOperandLayoutsAttr(), kcall.getResultLayoutsAttr(),
235-
kcall.getArgAttrsAttr(), kcall.getResAttrsAttr(),
236-
kcall.getOutputOperandAliasesAttr(), kcall.getXlaSideEffectFreeAttr());
232+
SymbolRefAttr::get(kcall.getContext(), callName, {}), kcall.getInputs(),
233+
kcall.getBackendConfigAttr(), kcall.getOperandLayoutsAttr(),
234+
kcall.getResultLayoutsAttr(), kcall.getArgAttrsAttr(),
235+
kcall.getResAttrsAttr(), kcall.getOutputOperandAliasesAttr(),
236+
kcall.getXlaSideEffectFreeAttr());
237237
kcall.replaceAllUsesWith(replacement);
238238
kcall.erase();
239239
return true;
@@ -398,11 +398,11 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
398398
OpBuilder rewriter(kcall);
399399
auto replacement = rewriter.create<enzymexla::JITCallOp>(
400400
kcall.getLoc(), kcall.getResultTypes(),
401-
mlir::FlatSymbolRefAttr::get(kcall.getContext(), callName),
402-
kcall.getInputs(), kcall.getBackendConfigAttr(),
403-
kcall.getOperandLayoutsAttr(), kcall.getResultLayoutsAttr(),
404-
kcall.getArgAttrsAttr(), kcall.getResAttrsAttr(),
405-
kcall.getOutputOperandAliasesAttr(), kcall.getXlaSideEffectFreeAttr());
401+
SymbolRefAttr::get(kcall.getContext(), callName, {}), kcall.getInputs(),
402+
kcall.getBackendConfigAttr(), kcall.getOperandLayoutsAttr(),
403+
kcall.getResultLayoutsAttr(), kcall.getArgAttrsAttr(),
404+
kcall.getResAttrsAttr(), kcall.getOutputOperandAliasesAttr(),
405+
kcall.getXlaSideEffectFreeAttr());
406406
kcall.replaceAllUsesWith(replacement);
407407
kcall.erase();
408408
return true;

src/enzyme_ad/jax/Passes/RemoveDuplicateFuncDef.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,6 @@ struct RemoveDuplicateFuncDefPass
9393
[&](SymbolUserOpInterface callOp) { callOps.push_back(callOp); });
9494

9595
for (SymbolUserOpInterface symbolUserOp : callOps) {
96-
if (auto kernelOp =
97-
dyn_cast<enzymexla::KernelCallOp>(symbolUserOp.getOperation())) {
98-
auto sym = kernelOp.getFn();
99-
Operation *op = symbolTable.lookup(sym);
100-
assert(op && "Kernel function not found");
101-
auto funcOp = dyn_cast<FunctionOpInterface>(op);
102-
assert(funcOp && "Kernel function is not a function");
103-
if (equivalenceMap.count(funcOp.getNameAttr()))
104-
kernelOp.setFn(equivalenceMap[funcOp.getNameAttr()]);
105-
}
10696
if (auto callOp =
10797
dyn_cast<CallOpInterface>(symbolUserOp.getOperation())) {
10898
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(

0 commit comments

Comments
 (0)