1- // ===- CHLOAutoDiffOpInterfaceImpl .cpp - Interface external model ---- ----===//
1+ // ===- EnzymeXLAAutoDiffOpInterfaceImpl .cpp - Interface external model ----===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
77// ===----------------------------------------------------------------------===//
88//
99// This file contains the external model implementation of the automatic
10- // differentiation op interfaces for the upstream MLIR arithmetic dialect.
10+ // differentiation op interfaces for the EnzymeXLA dialect.
1111//
1212// ===----------------------------------------------------------------------===//
1313
@@ -192,6 +192,31 @@ struct GPUWrapperOpInterfaceReverse
192192 MGradientUtilsReverse *gutils) const {}
193193};
194194
195+ class Pointer2MemrefRev : public ReverseAutoDiffOpInterface ::ExternalModel<
196+ Pointer2MemrefRev, enzymexla::Pointer2MemrefOp> {
197+ public:
198+ LogicalResult createReverseModeAdjoint (Operation *orig, OpBuilder &builder,
199+ MGradientUtilsReverse *gutils,
200+ SmallVector<Value> caches) const {
201+ return success ();
202+ }
203+
204+ SmallVector<Value> cacheValues (Operation *orig,
205+ MGradientUtilsReverse *gutils) const {
206+ return SmallVector<Value>();
207+ }
208+
209+ void createShadowValues (Operation *op, OpBuilder &builder,
210+ MGradientUtilsReverse *gutils) const {
211+ auto p2m = cast<enzymexla::Pointer2MemrefOp>(op);
212+ if (!gutils->isConstantValue (p2m)) {
213+ Value dres = gutils->invertPointerM (p2m.getSource (), builder);
214+ Value shadow = builder.create <enzymexla::Pointer2MemrefOp>(
215+ p2m.getLoc (), p2m.getType (), dres);
216+ gutils->setDiffe (p2m, shadow, builder);
217+ }
218+ }
219+ };
195220} // namespace
196221
197222void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface (
@@ -200,6 +225,7 @@ void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface(
200225 registerInterfaces (context);
201226 GPUWrapperOp::attachInterface<GPUWrapperOpInterfaceReverse>(*context);
202227 GPUWrapperOp::attachInterface<GPUWrapperOpEnzymeOpsRemover>(*context);
228+ enzymexla::Pointer2MemrefOp::attachInterface<Pointer2MemrefRev>(*context);
203229 context->loadDialect <stablehlo::StablehloDialect>();
204230 });
205231}
0 commit comments