 #include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
 #include "Enzyme/MLIR/Interfaces/GradientUtils.h"
 #include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/RegionUtils.h"

 #include "Dialect/Ops.h"
 #include "mlir/IR/TypeSupport.h"
@@ -68,12 +73,92 @@ struct GPUWrapperOpEnzymeOpsRemover |
     if (gradients.empty() && pushedCaches.empty())
       return success();

-    if (gradients.size())
-      return failure();
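+    // Merge repeated pushes of the same SSA value into a single CacheInfo so
+    // that each cached value is handled once below.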
+    llvm::MapVector<Value, CacheInfo> cachesMap;
+    for (auto &it : *wrapOp.getBody()) {
+      Operation *op = &it;
+      if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
+        CacheInfo info(pushOp.getCache());
+        if (cachesMap.contains(pushOp.getValue()))
+          info = info.merge(cachesMap.lookup(pushOp.getValue()), rewriter);
+        cachesMap[pushOp.getValue()] = info;
+      }
+    }
+    SmallVector<CacheInfo> caches =
+        llvm::map_to_vector(cachesMap, [](auto p) { return std::get<1>(p); });
+
+    if (caches.empty())
+      return success();
+
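+    // Values defined above the wrapper are already visible at the insertion
+    // point outside of it; seeding `visited` with them stops the backward
+    // walk at the region boundary.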
+    SetVector<Value> visited;
+    getUsedValuesDefinedAbove(wrapOp.getBodyRegion(), visited);
+    SmallVector<Value> frontier = llvm::map_to_vector(
+        caches, [](CacheInfo info) { return info.pushedValue(); });
+    SetVector<Operation *> opsToMove;
+    // Traverse backward from the pushed values to find the operations they
+    // depend on.
+    while (!frontier.empty()) {
+      Value v = frontier.back();
+      Operation *definingOp = v.getDefiningOp();
+      frontier.pop_back();
+
+      if (!definingOp)
+        continue;
+
+      // Ops that read or write memory cannot safely be hoisted out of the
+      // wrapper; allocations and frees are assumed legal to move.
+      if (hasEffect<MemoryEffects::Read>(definingOp) ||
+          hasEffect<MemoryEffects::Write>(definingOp)) {
+        definingOp->emitError() << "cannot move op with side effects";
+        return failure();
+      }
+      opsToMove.insert(definingOp);
+
+      for (Value operand : definingOp->getOperands()) {
+        if (visited.contains(operand))
+          continue;
+
+        frontier.push_back(operand);
+        visited.insert(operand);
+      }
+    }
 
-    if (pushedCaches.size())
-      return failure();
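+    // Clone the computed slice in front of the forward wrapper. Since each
+    // original op is RAUW'd with its clone, operands of earlier clones that
+    // still referenced in-wrapper values are fixed up as later ops are moved.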
+    // Move the push and dependent values outside of the wrapper.
+    OpBuilder::InsertionGuard guard(rewriter);
+    IRMapping map;
+    rewriter.setInsertionPoint(wrapOp);
+    for (Operation *toMove : llvm::reverse(opsToMove)) {
+      Operation *cloned = rewriter.clone(*toMove, map);
+      toMove->replaceAllUsesWith(cloned->getResults());
+
+      if (auto allocOp = dyn_cast<memref::AllocOp>(cloned)) {
+        // Assume GPU allocations need to be in address space 1.
+        auto gpuAlloc = gpu::AllocOp::create(
+            rewriter, allocOp.getLoc(),
+            *allocOp.getType().clonePtrWith(rewriter.getI64IntegerAttr(1),
+                                            std::nullopt),
+            /*asyncDependencies=*/ValueRange(), allocOp.getDynamicSizes(),
+            /*symbolOperands=*/ValueRange());
+        allocOp.replaceAllUsesWith(gpuAlloc.getResult(0));
+        rewriter.eraseOp(allocOp);
+      }
+    }
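+    // The hoisted ops' originals inside the wrapper are now unused; they are
+    // left behind for later dead-code elimination rather than erased here.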
 
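+    // Hoist each push before the forward wrapper and its matching pop before
+    // the corresponding reverse wrapper so the cache lives outside the GPU
+    // regions.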
+    for (auto &info : caches) {
+      rewriter.moveOpBefore(info.pushOp, wrapOp);
+      auto revWrapper = info.popOp->getParentOfType<enzymexla::GPUWrapperOp>();
+      assert(revWrapper && "failed to find reverse gpu_wrapper");
+      rewriter.moveOpBefore(info.popOp, revWrapper);
+
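+      // Deallocation is handled below with gpu.dealloc, so drop any stale
+      // memref.dealloc users of the popped value.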
+      // Use make_early_inc_range: erasing a user while iterating the use
+      // list would otherwise invalidate the iterator.
+      for (auto user :
+           llvm::make_early_inc_range(info.popOp.getResult().getUsers())) {
+        if (isa<memref::DeallocOp>(user)) {
+          rewriter.eraseOp(user);
+        }
+      }
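+      // Free the cached buffer once the reverse wrapper no longer needs it.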
+      rewriter.setInsertionPointAfter(revWrapper);
+      gpu::DeallocOp::create(rewriter, wrapOp.getLoc(), TypeRange(),
+                             info.popOp.getResult());
+    }
+
+    return success();
     // TODO need to convert to gpu allocations and conversion/copy

     /*
@@ -213,7 +298,7 @@ class Pointer2MemrefRev : public ReverseAutoDiffOpInterface::ExternalModel< |
     Value dres = gutils->invertPointerM(p2m.getSource(), builder);
     Value shadow = builder.create<enzymexla::Pointer2MemrefOp>(
         p2m.getLoc(), p2m.getType(), dres);
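+    // Memref shadows are registered as inverted pointers instead of being
+    // accumulated as a derivative value.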
-    gutils->setDiffe(p2m, shadow, builder);
+    gutils->setInvertedPointer(p2m, shadow);
     }
   }
 };