Skip to content

Commit 22ac7f2

Browse files
committed
removal interface for enzymexla.gpu_wrapper
1 parent b0a34c3 commit 22ac7f2

File tree

2 files changed

+90
-7
lines changed

2 files changed

+90
-7
lines changed

src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
1616
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
1717
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
18+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
19+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
20+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
21+
#include "mlir/Dialect/SCF/IR/SCF.h"
1822
#include "mlir/IR/DialectRegistry.h"
1923
#include "mlir/Support/LogicalResult.h"
24+
#include "mlir/Transforms/RegionUtils.h"
2025

2126
#include "Dialect/Ops.h"
2227
#include "mlir/IR/TypeSupport.h"
@@ -68,12 +73,92 @@ struct GPUWrapperOpEnzymeOpsRemover
6873
if (gradients.empty() && pushedCaches.empty())
6974
return success();
7075

71-
if (gradients.size())
72-
return failure();
76+
llvm::MapVector<Value, CacheInfo> cachesMap;
77+
for (auto &it : *wrapOp.getBody()) {
78+
Operation *op = &it;
79+
if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
80+
CacheInfo info(pushOp.getCache());
81+
if (cachesMap.contains(pushOp.getValue()))
82+
info = info.merge(cachesMap.lookup(pushOp.getValue()), rewriter);
83+
cachesMap[pushOp.getValue()] = info;
84+
}
85+
}
86+
SmallVector<CacheInfo> caches =
87+
llvm::map_to_vector(cachesMap, [](auto p) { return std::get<1>(p); });
88+
89+
if (caches.empty())
90+
return success();
91+
92+
SetVector<Value> visited;
93+
getUsedValuesDefinedAbove(wrapOp.getBodyRegion(), visited);
94+
SmallVector<Value> frontier = llvm::map_to_vector(
95+
caches, [](CacheInfo info) { return info.pushedValue(); });
96+
SetVector<Operation *> opsToMove;
97+
// Traverse backward from pushed values to find operations that the pushed
98+
// value depends on
99+
while (!frontier.empty()) {
100+
Value v = frontier.back();
101+
Operation *definingOp = v.getDefiningOp();
102+
frontier.pop_back();
103+
104+
if (!definingOp)
105+
continue;
106+
107+
// Assume allocations and frees are legal to move
108+
if (hasEffect<MemoryEffects::Read>(definingOp) ||
109+
hasEffect<MemoryEffects::Write>(definingOp)) {
110+
definingOp->emitError() << "cannot move op with side effects";
111+
return failure();
112+
}
113+
opsToMove.insert(definingOp);
114+
115+
for (Value operand : definingOp->getOperands()) {
116+
if (visited.contains(operand))
117+
continue;
118+
119+
frontier.push_back(operand);
120+
visited.insert(operand);
121+
}
122+
}
73123

74-
if (pushedCaches.size())
75-
return failure();
124+
// Move the push and dependent values outside of the wrapper
125+
OpBuilder::InsertionGuard guard(rewriter);
126+
IRMapping map;
127+
rewriter.setInsertionPoint(wrapOp);
128+
for (Operation *toMove : llvm::reverse(opsToMove)) {
129+
Operation *cloned = rewriter.clone(*toMove, map);
130+
toMove->replaceAllUsesWith(cloned->getResults());
131+
132+
if (auto allocOp = dyn_cast<memref::AllocOp>(cloned)) {
133+
// Assume GPU allocations need to be in address space 1
134+
auto gpuAlloc = gpu::AllocOp::create(
135+
rewriter, allocOp.getLoc(),
136+
*allocOp.getType().clonePtrWith(rewriter.getI64IntegerAttr(1),
137+
std::nullopt),
138+
/*asyncDependencies=*/ValueRange(), allocOp.getDynamicSizes(),
139+
/*symbolOperands=*/ValueRange());
140+
allocOp.replaceAllUsesWith(gpuAlloc.getResult(0));
141+
rewriter.eraseOp(allocOp);
142+
}
143+
}
76144

145+
for (auto &info : caches) {
146+
rewriter.moveOpBefore(info.pushOp, wrapOp);
147+
auto revWrapper = info.popOp->getParentOfType<enzymexla::GPUWrapperOp>();
148+
assert(revWrapper && "failed to find reverse gpu_wrapper");
149+
rewriter.moveOpBefore(info.popOp, revWrapper);
150+
151+
for (auto user : info.popOp.getResult().getUsers()) {
152+
if (isa<memref::DeallocOp>(user)) {
153+
rewriter.eraseOp(user);
154+
}
155+
}
156+
rewriter.setInsertionPointAfter(revWrapper);
157+
gpu::DeallocOp::create(rewriter, wrapOp.getLoc(), TypeRange(),
158+
info.popOp.getResult());
159+
}
160+
161+
return success();
77162
// TODO need to convert to gpu allocations and conversion/copy
78163

79164
/*
@@ -213,7 +298,7 @@ class Pointer2MemrefRev : public ReverseAutoDiffOpInterface::ExternalModel<
213298
Value dres = gutils->invertPointerM(p2m.getSource(), builder);
214299
Value shadow = builder.create<enzymexla::Pointer2MemrefOp>(
215300
p2m.getLoc(), p2m.getType(), dres);
216-
gutils->setDiffe(p2m, shadow, builder);
301+
gutils->setInvertedPointer(p2m, shadow);
217302
}
218303
}
219304
};

src/enzyme_ad/jax/Implementations/XLADerivatives.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,13 @@
1010

1111
namespace mlir {
1212
namespace enzyme {
13-
void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
1413
void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1514
void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1615
void registerCHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1716
void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
1817

1918
static inline void
2019
registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
21-
registerEnzymeXLADialectAutoDiffInterface(registry);
2220
registerMHLODialectAutoDiffInterface(registry);
2321
registerStableHLODialectAutoDiffInterface(registry);
2422
registerCHLODialectAutoDiffInterface(registry);

0 commit comments

Comments
 (0)