1- // ===- CHLOAutoDiffOpInterfaceImpl .cpp - Interface external model ---- ----===//
1+ // ===- EnzymeXLAAutoDiffOpInterfaceImpl .cpp - Interface external model ----===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
77// ===----------------------------------------------------------------------===//
88//
99// This file contains the external model implementation of the automatic
10- // differentiation op interfaces for the upstream MLIR arithmetic dialect.
10+ // differentiation op interfaces for the EnzymeXLA dialect.
1111//
1212// ===----------------------------------------------------------------------===//
1313
@@ -193,6 +193,31 @@ struct GPUWrapperOpInterfaceReverse
193193 MGradientUtilsReverse *gutils) const {}
194194};
195195
196+ class Pointer2MemrefRev : public ReverseAutoDiffOpInterface ::ExternalModel<
197+ Pointer2MemrefRev, enzymexla::Pointer2MemrefOp> {
198+ public:
199+ LogicalResult createReverseModeAdjoint (Operation *orig, OpBuilder &builder,
200+ MGradientUtilsReverse *gutils,
201+ SmallVector<Value> caches) const {
202+ return success ();
203+ }
204+
205+ SmallVector<Value> cacheValues (Operation *orig,
206+ MGradientUtilsReverse *gutils) const {
207+ return SmallVector<Value>();
208+ }
209+
210+ void createShadowValues (Operation *op, OpBuilder &builder,
211+ MGradientUtilsReverse *gutils) const {
212+ auto p2m = cast<enzymexla::Pointer2MemrefOp>(op);
213+ if (!gutils->isConstantValue (p2m)) {
214+ Value dres = gutils->invertPointerM (p2m.getSource (), builder);
215+ Value shadow = builder.create <enzymexla::Pointer2MemrefOp>(
216+ p2m.getLoc (), p2m.getType (), dres);
217+ gutils->setDiffe (p2m, shadow, builder);
218+ }
219+ }
220+ };
196221} // namespace
197222
198223void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface (
@@ -201,6 +226,7 @@ void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface(
201226 registerInterfaces (context);
202227 GPUWrapperOp::attachInterface<GPUWrapperOpInterfaceReverse>(*context);
203228 GPUWrapperOp::attachInterface<GPUWrapperOpEnzymeOpsRemover>(*context);
229+ enzymexla::Pointer2MemrefOp::attachInterface<Pointer2MemrefRev>(*context);
204230
205231 // Register batching interfaces
206232 JITCallOp::attachInterface<SHLOGenericBatchOpInterface<JITCallOp>>(
0 commit comments