Skip to content

Commit a37bcd4

Browse files
committed
enzymexla.pointer2memref derivative
1 parent edb3814 commit a37bcd4

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- CHLOAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
1+
//===- EnzymeXLAAutoDiffOpInterfaceImpl.cpp - Interface external model ----===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file contains the external model implementation of the automatic
10-
// differentiation op interfaces for the upstream MLIR arithmetic dialect.
10+
// differentiation op interfaces for the EnzymeXLA dialect.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

@@ -193,6 +193,31 @@ struct GPUWrapperOpInterfaceReverse
193193
MGradientUtilsReverse *gutils) const {}
194194
};
195195

196+
class Pointer2MemrefRev : public ReverseAutoDiffOpInterface::ExternalModel<
197+
Pointer2MemrefRev, enzymexla::Pointer2MemrefOp> {
198+
public:
199+
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
200+
MGradientUtilsReverse *gutils,
201+
SmallVector<Value> caches) const {
202+
return success();
203+
}
204+
205+
SmallVector<Value> cacheValues(Operation *orig,
206+
MGradientUtilsReverse *gutils) const {
207+
return SmallVector<Value>();
208+
}
209+
210+
void createShadowValues(Operation *op, OpBuilder &builder,
211+
MGradientUtilsReverse *gutils) const {
212+
auto p2m = cast<enzymexla::Pointer2MemrefOp>(op);
213+
if (!gutils->isConstantValue(p2m)) {
214+
Value dres = gutils->invertPointerM(p2m.getSource(), builder);
215+
Value shadow = builder.create<enzymexla::Pointer2MemrefOp>(
216+
p2m.getLoc(), p2m.getType(), dres);
217+
gutils->setDiffe(p2m, shadow, builder);
218+
}
219+
}
220+
};
196221
} // namespace
197222

198223
void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface(
@@ -201,6 +226,7 @@ void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface(
201226
registerInterfaces(context);
202227
GPUWrapperOp::attachInterface<GPUWrapperOpInterfaceReverse>(*context);
203228
GPUWrapperOp::attachInterface<GPUWrapperOpEnzymeOpsRemover>(*context);
229+
enzymexla::Pointer2MemrefOp::attachInterface<Pointer2MemrefRev>(*context);
204230

205231
// Register batching interfaces
206232
JITCallOp::attachInterface<SHLOGenericBatchOpInterface<JITCallOp>>(

src/enzyme_ad/jax/Implementations/XLADerivatives.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010

1111
namespace mlir {
1212
namespace enzyme {
13+
void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
1314
void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1415
void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1516
void registerCHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1617
void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
1718

1819
static inline void
1920
registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
21+
registerEnzymeXLADialectAutoDiffInterface(registry);
2022
registerMHLODialectAutoDiffInterface(registry);
2123
registerStableHLODialectAutoDiffInterface(registry);
2224
registerCHLODialectAutoDiffInterface(registry);

0 commit comments

Comments
 (0)