Skip to content

Commit b0a34c3

Browse files
committed
enzymexla.pointer2memref derivative
1 parent bce12a6 commit b0a34c3

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- CHLOAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
1+
//===- EnzymeXLAAutoDiffOpInterfaceImpl.cpp - Interface external model ----===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file contains the external model implementation of the automatic
10-
// differentiation op interfaces for the upstream MLIR arithmetic dialect.
10+
// differentiation op interfaces for the EnzymeXLA dialect.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

@@ -192,6 +192,31 @@ struct GPUWrapperOpInterfaceReverse
192192
MGradientUtilsReverse *gutils) const {}
193193
};
194194

195+
class Pointer2MemrefRev : public ReverseAutoDiffOpInterface::ExternalModel<
196+
Pointer2MemrefRev, enzymexla::Pointer2MemrefOp> {
197+
public:
198+
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
199+
MGradientUtilsReverse *gutils,
200+
SmallVector<Value> caches) const {
201+
return success();
202+
}
203+
204+
SmallVector<Value> cacheValues(Operation *orig,
205+
MGradientUtilsReverse *gutils) const {
206+
return SmallVector<Value>();
207+
}
208+
209+
void createShadowValues(Operation *op, OpBuilder &builder,
210+
MGradientUtilsReverse *gutils) const {
211+
auto p2m = cast<enzymexla::Pointer2MemrefOp>(op);
212+
if (!gutils->isConstantValue(p2m)) {
213+
Value dres = gutils->invertPointerM(p2m.getSource(), builder);
214+
Value shadow = builder.create<enzymexla::Pointer2MemrefOp>(
215+
p2m.getLoc(), p2m.getType(), dres);
216+
gutils->setDiffe(p2m, shadow, builder);
217+
}
218+
}
219+
};
195220
} // namespace
196221

197222
void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface(
@@ -200,6 +225,7 @@ void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface(
200225
registerInterfaces(context);
201226
GPUWrapperOp::attachInterface<GPUWrapperOpInterfaceReverse>(*context);
202227
GPUWrapperOp::attachInterface<GPUWrapperOpEnzymeOpsRemover>(*context);
228+
enzymexla::Pointer2MemrefOp::attachInterface<Pointer2MemrefRev>(*context);
203229
context->loadDialect<stablehlo::StablehloDialect>();
204230
});
205231
}

src/enzyme_ad/jax/Implementations/XLADerivatives.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010

1111
namespace mlir {
1212
namespace enzyme {
13+
void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
1314
void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1415
void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1516
void registerCHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1617
void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
1718

1819
static inline void
1920
registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
21+
registerEnzymeXLADialectAutoDiffInterface(registry);
2022
registerMHLODialectAutoDiffInterface(registry);
2123
registerStableHLODialectAutoDiffInterface(registry);
2224
registerCHLODialectAutoDiffInterface(registry);

0 commit comments

Comments
 (0)