Skip to content

Commit 7020f03

Browse files
committed
removal interface for enzymexla.gpu_wrapper
1 parent a37bcd4 commit 7020f03

File tree

2 files changed

+90
-7
lines changed

2 files changed

+90
-7
lines changed

src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
1616
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
1717
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
18+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
19+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
20+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
21+
#include "mlir/Dialect/SCF/IR/SCF.h"
1822
#include "mlir/IR/DialectRegistry.h"
1923
#include "mlir/Support/LogicalResult.h"
24+
#include "mlir/Transforms/RegionUtils.h"
2025
#include "src/enzyme_ad/jax/Implementations/SHLOGenericBatchOpInterface.h"
2126

2227
#include "Dialect/Ops.h"
@@ -69,12 +74,92 @@ struct GPUWrapperOpEnzymeOpsRemover
6974
if (gradients.empty() && pushedCaches.empty())
7075
return success();
7176

72-
if (gradients.size())
73-
return failure();
77+
llvm::MapVector<Value, CacheInfo> cachesMap;
78+
for (auto &it : *wrapOp.getBody()) {
79+
Operation *op = &it;
80+
if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
81+
CacheInfo info(pushOp.getCache());
82+
if (cachesMap.contains(pushOp.getValue()))
83+
info = info.merge(cachesMap.lookup(pushOp.getValue()), rewriter);
84+
cachesMap[pushOp.getValue()] = info;
85+
}
86+
}
87+
SmallVector<CacheInfo> caches =
88+
llvm::map_to_vector(cachesMap, [](auto p) { return std::get<1>(p); });
89+
90+
if (caches.empty())
91+
return success();
92+
93+
SetVector<Value> visited;
94+
getUsedValuesDefinedAbove(wrapOp.getBodyRegion(), visited);
95+
SmallVector<Value> frontier = llvm::map_to_vector(
96+
caches, [](CacheInfo info) { return info.pushedValue(); });
97+
SetVector<Operation *> opsToMove;
98+
// Traverse backward from pushed values to find operations that the pushed
99+
// value depends on
100+
while (!frontier.empty()) {
101+
Value v = frontier.back();
102+
Operation *definingOp = v.getDefiningOp();
103+
frontier.pop_back();
104+
105+
if (!definingOp)
106+
continue;
107+
108+
// Assume allocations and frees are legal to move
109+
if (hasEffect<MemoryEffects::Read>(definingOp) ||
110+
hasEffect<MemoryEffects::Write>(definingOp)) {
111+
definingOp->emitError() << "cannot move op with side effects";
112+
return failure();
113+
}
114+
opsToMove.insert(definingOp);
115+
116+
for (Value operand : definingOp->getOperands()) {
117+
if (visited.contains(operand))
118+
continue;
119+
120+
frontier.push_back(operand);
121+
visited.insert(operand);
122+
}
123+
}
74124

75-
if (pushedCaches.size())
76-
return failure();
125+
// Move the push and dependent values outside of the wrapper
126+
OpBuilder::InsertionGuard guard(rewriter);
127+
IRMapping map;
128+
rewriter.setInsertionPoint(wrapOp);
129+
for (Operation *toMove : llvm::reverse(opsToMove)) {
130+
Operation *cloned = rewriter.clone(*toMove, map);
131+
toMove->replaceAllUsesWith(cloned->getResults());
132+
133+
if (auto allocOp = dyn_cast<memref::AllocOp>(cloned)) {
134+
// Assume GPU allocations need to be in address space 1
135+
auto gpuAlloc = gpu::AllocOp::create(
136+
rewriter, allocOp.getLoc(),
137+
*allocOp.getType().clonePtrWith(rewriter.getI64IntegerAttr(1),
138+
std::nullopt),
139+
/*asyncDependencies=*/ValueRange(), allocOp.getDynamicSizes(),
140+
/*symbolOperands=*/ValueRange());
141+
allocOp.replaceAllUsesWith(gpuAlloc.getResult(0));
142+
rewriter.eraseOp(allocOp);
143+
}
144+
}
77145

146+
for (auto &info : caches) {
147+
rewriter.moveOpBefore(info.pushOp, wrapOp);
148+
auto revWrapper = info.popOp->getParentOfType<enzymexla::GPUWrapperOp>();
149+
assert(revWrapper && "failed to find reverse gpu_wrapper");
150+
rewriter.moveOpBefore(info.popOp, revWrapper);
151+
152+
for (auto user : info.popOp.getResult().getUsers()) {
153+
if (isa<memref::DeallocOp>(user)) {
154+
rewriter.eraseOp(user);
155+
}
156+
}
157+
rewriter.setInsertionPointAfter(revWrapper);
158+
gpu::DeallocOp::create(rewriter, wrapOp.getLoc(), TypeRange(),
159+
info.popOp.getResult());
160+
}
161+
162+
return success();
78163
// TODO need to convert to gpu allocations and conversion/copy
79164

80165
/*
@@ -214,7 +299,7 @@ class Pointer2MemrefRev : public ReverseAutoDiffOpInterface::ExternalModel<
214299
Value dres = gutils->invertPointerM(p2m.getSource(), builder);
215300
Value shadow = builder.create<enzymexla::Pointer2MemrefOp>(
216301
p2m.getLoc(), p2m.getType(), dres);
217-
gutils->setDiffe(p2m, shadow, builder);
302+
gutils->setInvertedPointer(p2m, shadow);
218303
}
219304
}
220305
};

src/enzyme_ad/jax/Implementations/XLADerivatives.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,13 @@
1010

1111
namespace mlir {
1212
namespace enzyme {
13-
void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
1413
void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1514
void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1615
void registerCHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1716
void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
1817

1918
static inline void
2019
registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
21-
registerEnzymeXLADialectAutoDiffInterface(registry);
2220
registerMHLODialectAutoDiffInterface(registry);
2321
registerStableHLODialectAutoDiffInterface(registry);
2422
registerCHLODialectAutoDiffInterface(registry);

0 commit comments

Comments
 (0)