Skip to content

Commit 1631a64

Browse files
authored
shlo raise (#1041)
* shlo raise * fmt
1 parent 7872c77 commit 1631a64

File tree

3 files changed

+129
-0
lines changed

3 files changed

+129
-0
lines changed

src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,20 @@ def GPUWrapperOp : EnzymeXLA_Op<"gpu_wrapper", [
152152
OpBuilder<(ins)>];
153153
}
154154

155+
def XLAWrapperOp: EnzymeXLA_Op<"xla_wrapper", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, DeclareOpInterfaceMethods<CallOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
156+
let summary = "XLA Call operation";
157+
158+
let arguments = (ins
159+
SymbolRefAttr:$fn,
160+
Variadic<AnyType>:$inputs
161+
);
162+
163+
let assemblyFormat = [{
164+
$fn ` ` `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
165+
}];
166+
167+
}
168+
155169
def GPUErrorOp : EnzymeXLA_Op<"gpu_error", [
156170
RecursiveMemoryEffects,
157171
SingleBlockImplicitTerminator<"enzymexla::PolygeistYieldOp">]>,

src/enzyme_ad/jax/Dialect/Ops.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,3 +1426,48 @@ void GPUErrorOp::build(OpBuilder &builder, OperationState &result) {
14261426
builder.createBlock(bodyRegion);
14271427
GPUErrorOp::ensureTerminator(*bodyRegion, builder, result.location);
14281428
}
1429+
1430+
LogicalResult
1431+
XLAWrapperOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1432+
// TODO: Verify that the result type is same as the type of the referenced
1433+
// func.func op.
1434+
auto global = symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1435+
*this, getFnAttr());
1436+
if (!global)
1437+
return emitOpError("'")
1438+
<< getFn() << "' does not reference a valid global funcOp";
1439+
1440+
return success();
1441+
}
1442+
1443+
void XLAWrapperOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1444+
auto symbol = cast<SymbolRefAttr>(callee);
1445+
setFnAttr(cast<FlatSymbolRefAttr>(symbol));
1446+
}
1447+
1448+
CallInterfaceCallable XLAWrapperOp::getCallableForCallee() { return getFn(); }
1449+
1450+
MutableOperandRange XLAWrapperOp::getArgOperandsMutable() {
1451+
return getInputsMutable();
1452+
}
1453+
1454+
Operation::operand_range XLAWrapperOp::getArgOperands() { return getInputs(); }
1455+
1456+
ArrayAttr XLAWrapperOp::getArgAttrsAttr() { return nullptr; }
1457+
1458+
void XLAWrapperOp::setArgAttrsAttr(mlir::ArrayAttr attr) { (void)attr; }
1459+
1460+
ArrayAttr XLAWrapperOp::getResAttrsAttr() { return nullptr; }
1461+
1462+
void XLAWrapperOp::setResAttrsAttr(ArrayAttr attr) { (void)attr; }
1463+
1464+
Attribute XLAWrapperOp::removeArgAttrsAttr() { return nullptr; }
1465+
1466+
Attribute XLAWrapperOp::removeResAttrsAttr() { return nullptr; }
1467+
1468+
void XLAWrapperOp::getEffects(
1469+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1470+
&effects) {
1471+
effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
1472+
effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Write>());
1473+
}

src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2451,6 +2451,76 @@ struct AffineToStableHLORaisingPass
24512451
}
24522452
funcs.pop_back();
24532453
}
2454+
std::vector<enzymexla::GPUWrapperOp> gwrap;
2455+
op->walk([&](enzymexla::GPUWrapperOp g) { gwrap.push_back(g); });
2456+
for (auto g : gwrap) {
2457+
auto modOp = g->getParentOfType<ModuleOp>();
2458+
Block *body = &g->getRegion(0).front();
2459+
Block *newBlock = new Block();
2460+
2461+
IRMapping mapping;
2462+
mapping.map(body, newBlock);
2463+
SetVector<Value> operands;
2464+
getUsedValuesDefinedAbove(g->getRegion(0), operands);
2465+
2466+
SmallVector<Type> tensorTypes;
2467+
for (auto arg : operands) {
2468+
auto MT = cast<MemRefType>(arg.getType());
2469+
auto TT = RankedTensorType::get(MT.getShape(), MT.getElementType());
2470+
auto newArg = newBlock->addArgument(TT, arg.getLoc());
2471+
mapping.map(arg, newArg);
2472+
tensorTypes.push_back(TT);
2473+
}
2474+
2475+
auto newFuncType =
2476+
FunctionType::get(g->getContext(), tensorTypes, tensorTypes);
2477+
2478+
std::string name = "raised";
2479+
2480+
auto newFunc = func::FuncOp::create(g->getLoc(), name, newFuncType);
2481+
newFunc.setVisibility(mlir::SymbolTable::Visibility::Private);
2482+
newFunc.getBody().push_back(newBlock);
2483+
2484+
OpBuilder builder(newBlock, newBlock->end());
2485+
2486+
bool anyFailed = false;
2487+
2488+
llvm::DenseMap<Value, affine::AffineValueMap> maps;
2489+
2490+
ParallelContext emptyPc = ParallelContext::getEmpty(options);
2491+
for (auto &it : body->without_terminator()) {
2492+
anyFailed =
2493+
tryRaisingOpToStableHLO(&it, mapping, builder, maps, emptyPc)
2494+
.failed();
2495+
if (anyFailed)
2496+
break;
2497+
}
2498+
2499+
if (anyFailed) {
2500+
newFunc->erase();
2501+
return;
2502+
}
2503+
2504+
SmallVector<Value> results;
2505+
for (auto arg : operands) {
2506+
auto val = mapping.lookup(arg);
2507+
results.push_back(val);
2508+
}
2509+
2510+
builder.create<func::ReturnOp>(g->getLoc(), results);
2511+
modOp.getBody()->push_back(newFunc);
2512+
SymbolTable::setSymbolVisibility(newFunc,
2513+
SymbolTable::Visibility::Private);
2514+
2515+
{
2516+
OpBuilder builder(g);
2517+
auto newCall = builder.create<enzymexla::XLAWrapperOp>(
2518+
g->getLoc(), SymbolRefAttr::get(newFunc),
2519+
llvm::to_vector(operands));
2520+
g->erase();
2521+
anyRaised = true;
2522+
}
2523+
}
24542524

24552525
if (!anyRaised) {
24562526
markAllAnalysesPreserved();

0 commit comments

Comments
 (0)