Skip to content

Commit 693e185

Browse files
ptiedew and wsmoses
authored
Correct atan2 gradient (#2066)
* Correct atan2 gradient There was a minus sign offset * Refactor HLODerivative for Atan2Op calculations * Correct gradient for forward pass as well * Fix sign in CheckedDiv calculation * Update HLODerivatives.td * Remove HLODerivative for Atan2Op Removed HLODerivative definition for Atan2Op. * fix * fix --------- Co-authored-by: William Moses <gh@wsmoses.com>
1 parent b3ab8a3 commit 693e185

File tree

2 files changed

+26
-29
lines changed

2 files changed

+26
-29
lines changed

src/enzyme_ad/jax/Implementations/HLODerivatives.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -969,12 +969,12 @@ def : HLODerivative<"AddOp", (Op $x, $y),
969969
(Add (Shadow $x), (Shadow $y))
970970
>;
971971

972-
def : HLODerivative<"Atan2Op", (Op $x, $y),
972+
def : HLODerivative<"Atan2Op", (Op $y, $x),
973973
[
974-
(CheckedMul (DiffeRet), (Div (Neg $y), (Add (Pow $x, (HLOConstantFP<"2">)), (Pow $y, (HLOConstantFP<"2"> $y))))),
975-
(CheckedMul (DiffeRet), (Div $x, (Add (Pow $x, (HLOConstantFP<"2">)), (Pow $y, (HLOConstantFP<"2"> $y)))))
974+
(CheckedMul (DiffeRet), (Div $x, (Add (Mul $x, $x), (Mul $y, $y)))),
975+
(CheckedMul (DiffeRet), (Div (Neg $y), (Add (Mul $x, $x), (Mul $y, $y)))),
976976
],
977-
(CheckedDiv (Sub (Mul $x, (Shadow $y)), (Mul $y, (Shadow $x))), (Add (Pow $x, (HLOConstantFP<"2">)), (Pow $y, (HLOConstantFP<"2"> $y))))
977+
(CheckedDiv (Sub (Mul $x, (Shadow $y)), (Mul $y, (Shadow $x))), (Add (Mul $x, $x), (Mul $y, $y)))
978978
>;
979979

980980
def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;

test/lit_tests/diffrules/stablehlo/atan2.mlir

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,32 @@ func.func @main(%a : tensor<2xf32>, %b : tensor<2xf32>) -> tensor<2xf32> {
77
}
88

99
// FORWARD: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
10-
// FORWARD-NEXT: %0 = stablehlo.multiply %arg0, %arg3 : tensor<2xf32>
11-
// FORWARD-NEXT: %1 = stablehlo.multiply %arg2, %arg1 : tensor<2xf32>
10+
// FORWARD-NEXT: %0 = stablehlo.multiply %arg2, %arg1 : tensor<2xf32>
11+
// FORWARD-NEXT: %1 = stablehlo.multiply %arg0, %arg3 : tensor<2xf32>
1212
// FORWARD-NEXT: %2 = stablehlo.subtract %0, %1 : tensor<2xf32>
13-
// FORWARD-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<2xf32>
14-
// FORWARD-NEXT: %3 = stablehlo.power %arg0, %cst : tensor<2xf32>
15-
// FORWARD-NEXT: %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<2xf32>
16-
// FORWARD-NEXT: %4 = stablehlo.power %arg2, %cst_0 : tensor<2xf32>
13+
// FORWARD-NEXT: %3 = stablehlo.multiply %arg2, %arg2 : tensor<2xf32>
14+
// FORWARD-NEXT: %4 = stablehlo.multiply %arg0, %arg0 : tensor<2xf32>
1715
// FORWARD-NEXT: %5 = stablehlo.add %3, %4 : tensor<2xf32>
1816
// FORWARD-NEXT: %6 = stablehlo.divide %2, %5 : tensor<2xf32>
1917
// FORWARD-NEXT: %7 = stablehlo.atan2 %arg0, %arg2 : tensor<2xf32>
2018
// FORWARD-NEXT: return %7, %6 : tensor<2xf32>, tensor<2xf32>
2119
// FORWARD-NEXT: }
2220

23-
// REVERSE: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
24-
// REVERSE-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<2xf32>
25-
// REVERSE-NEXT: %cst_0 = arith.constant dense<0.000000e+00> : tensor<2xf32>
26-
// REVERSE-NEXT: %0 = arith.addf %arg2, %cst_0 : tensor<2xf32>
27-
// REVERSE-NEXT: %1 = stablehlo.negate %arg1 : tensor<2xf32>
28-
// REVERSE-NEXT: %2 = stablehlo.power %arg0, %cst : tensor<2xf32>
29-
// REVERSE-NEXT: %3 = stablehlo.power %arg1, %cst : tensor<2xf32>
30-
// REVERSE-NEXT: %4 = stablehlo.add %2, %3 : tensor<2xf32>
31-
// REVERSE-NEXT: %5 = stablehlo.divide %1, %4 : tensor<2xf32>
32-
// REVERSE-NEXT: %6 = stablehlo.multiply %0, %5 : tensor<2xf32>
33-
// REVERSE-NEXT: %7 = arith.addf %6, %cst_0 : tensor<2xf32>
34-
// REVERSE-NEXT: %8 = stablehlo.power %arg0, %cst : tensor<2xf32>
35-
// REVERSE-NEXT: %9 = stablehlo.power %arg1, %cst : tensor<2xf32>
36-
// REVERSE-NEXT: %10 = stablehlo.add %8, %9 : tensor<2xf32>
37-
// REVERSE-NEXT: %11 = stablehlo.divide %arg0, %10 : tensor<2xf32>
38-
// REVERSE-NEXT: %12 = stablehlo.multiply %0, %11 : tensor<2xf32>
39-
// REVERSE-NEXT: %13 = arith.addf %12, %cst_0 : tensor<2xf32>
40-
// REVERSE-NEXT: return %7, %13 : tensor<2xf32>, tensor<2xf32>
41-
// REVERSE-NEXT: }
21+
// REVERSE: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
22+
// REVERSE-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<2xf32>
23+
// REVERSE-NEXT: %0 = arith.addf %arg2, %cst : tensor<2xf32>
24+
// REVERSE-NEXT: %1 = stablehlo.multiply %arg1, %arg1 : tensor<2xf32>
25+
// REVERSE-NEXT: %2 = stablehlo.multiply %arg0, %arg0 : tensor<2xf32>
26+
// REVERSE-NEXT: %3 = stablehlo.add %1, %2 : tensor<2xf32>
27+
// REVERSE-NEXT: %4 = stablehlo.divide %arg1, %3 : tensor<2xf32>
28+
// REVERSE-NEXT: %5 = stablehlo.multiply %0, %4 : tensor<2xf32>
29+
// REVERSE-NEXT: %6 = arith.addf %5, %cst : tensor<2xf32>
30+
// REVERSE-NEXT: %7 = stablehlo.negate %arg0 : tensor<2xf32>
31+
// REVERSE-NEXT: %8 = stablehlo.multiply %arg1, %arg1 : tensor<2xf32>
32+
// REVERSE-NEXT: %9 = stablehlo.multiply %arg0, %arg0 : tensor<2xf32>
33+
// REVERSE-NEXT: %10 = stablehlo.add %8, %9 : tensor<2xf32>
34+
// REVERSE-NEXT: %11 = stablehlo.divide %7, %10 : tensor<2xf32>
35+
// REVERSE-NEXT: %12 = stablehlo.multiply %0, %11 : tensor<2xf32>
36+
// REVERSE-NEXT: %13 = arith.addf %12, %cst : tensor<2xf32>
37+
// REVERSE-NEXT: return %6, %13 : tensor<2xf32>, tensor<2xf32>
38+
// REVERSE-NEXT: }

0 commit comments

Comments
 (0)