Skip to content

Commit ea6da1d

Browse files
committed
[Op] Add C++ gradient op for Atan2.
Cherry-pick from TensorFlow 5dee95f
1 parent 3af9ecd commit ea6da1d

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

tensorflow/cc/gradients/math_grad.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,19 @@ Status AtanhGrad(const Scope& scope, const Operation& op,
249249
}
250250
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);
251251

252+
Status Atan2Grad(const Scope& scope, const Operation& op,
253+
const std::vector<Output>& grad_inputs,
254+
std::vector<Output>* grad_outputs) {
255+
auto y = op.input(0);
256+
auto x = op.input(1);
257+
Output grad_inv = Div(scope, grad_inputs[0],
258+
Add(scope, Square(scope, x), Square(scope, y)));
259+
grad_outputs->push_back(Mul(scope, x, grad_inv));
260+
grad_outputs->push_back(Mul(scope, Neg(scope, y), grad_inv));
261+
return scope.status();
262+
}
263+
REGISTER_GRADIENT_OP("Atan2", Atan2Grad);
264+
252265
Status SigmoidGrad(const Scope& scope, const Operation& op,
253266
const std::vector<Output>& grad_inputs,
254267
std::vector<Output>* grad_outputs) {

tensorflow/cc/gradients/math_grad_test.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ using ops::Abs;
3131
using ops::Add;
3232
using ops::AddN;
3333
using ops::AddV2;
34+
using ops::Atan2;
3435
using ops::BatchMatMul;
3536
using ops::Cast;
3637
using ops::Const;
@@ -1015,5 +1016,13 @@ TEST_F(NaryGradTest, SelectV2_Broadcast) {
10151016
RunTest({x, y}, {x_shape, y_shape}, {z}, {x_shape});
10161017
}
10171018

1019+
TEST_F(NaryGradTest, Atan2Grad) {
1020+
TensorShape shape({3, 2, 5});
1021+
auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
1022+
auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
1023+
auto y = Atan2(scope_, x1, x2);
1024+
RunTest({x1, x2}, {shape, shape}, {y}, {shape});
1025+
}
1026+
10181027
} // namespace
10191028
} // namespace tensorflow

0 commit comments

Comments
 (0)