Skip to content

Commit ac3be7e

Browse files
tyb0807BuildKite
andauthored
Add derivative for LLVM:ExpOp (#2220)
Co-authored-by: BuildKite <[email protected]>
1 parent 96b8efc commit ac3be7e

File tree

4 files changed

+28
-0
lines changed

4 files changed

+28
-0
lines changed

enzyme/Enzyme/MLIR/Implementations/Common.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def TypeOf : Operation</*primal*/0, /*shadow*/0> {
121121

122122
class ComplexInst<string m> : Inst<m, "complex">;
123123
class ArithInst<string m> : Inst<m, "arith">;
124+
class LlvmInst<string m> : Inst<m, "LLVM">;
124125
class MathInst<string m> : Inst<m, "math">;
125126

126127
def AddF : ArithInst<"AddFOp">;
@@ -133,6 +134,9 @@ def RemF : ArithInst<"RemFOp">;
133134
def CheckedMulF : ArithInst<"MulFOp">;
134135
def CheckedDivF : ArithInst<"DivFOp">;
135136

137+
def LlvmCheckedMulF : LlvmInst<"FMulOp">;
138+
def LlvmExpF : LlvmInst<"ExpOp">;
139+
136140
def CosF : MathInst<"CosOp">;
137141
def SinF : MathInst<"SinOp">;
138142
def ExpF : MathInst<"ExpOp">;

enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "Interfaces/AutoDiffOpInterface.h"
1616
#include "Interfaces/AutoDiffTypeInterface.h"
1717
#include "Interfaces/GradientUtils.h"
18+
#include "Interfaces/GradientUtilsReverse.h"
1819
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1920
#include "mlir/IR/DialectRegistry.h"
2021
#include "mlir/Support/LogicalResult.h"

enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,9 @@ def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>;
2626
def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>;
2727

2828
def : AllocationOp<"LLVM", "AllocaOp">;
29+
30+
def : MLIRDerivative<"LLVM", "ExpOp", (Op $x),
31+
[
32+
(LlvmCheckedMulF (DiffeRet), (LlvmExpF $x))
33+
]
34+
>;

enzyme/test/MLIR/ForwardMode/llvm.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@ module {
1313
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
1414
return %r : f64
1515
}
16+
17+
func.func @exp(%x: f32) -> f32 {
18+
%0 = llvm.intr.exp(%x) : (f32) -> f32
19+
return %0 : f32
20+
}
21+
22+
func.func @dexp(%x: f32, %dx: f32) -> f32 {
23+
%r = enzyme.fwddiff @exp(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f32, f32) -> f32
24+
return %r : f32
25+
}
1626
}
1727

1828
// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
@@ -29,3 +39,10 @@ module {
2939
// CHECK-NEXT: %[[i7:.+]] = llvm.load %[[i1]] : !llvm.ptr -> f64
3040
// CHECK-NEXT: return %[[i6]] : f64
3141
// CHECK-NEXT: }
42+
43+
// CHECK: func.func private @fwddiffeexp(%[[arg0:.+]]: f32, %[[arg1:.+]]: f32) -> f32 {
44+
// CHECK-NEXT: %[[der:.+]] = llvm.intr.exp(%[[arg0]]) : (f32) -> f32
45+
// CHECK-NEXT: %[[res:.+]] = llvm.fmul %[[arg1]], %[[der]] : f32
46+
// CHECK-NEXT: %[[exp:.+]] = llvm.intr.exp(%[[arg0]]) : (f32) -> f32
47+
// CHECK-NEXT: return %[[res]] : f32
48+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)