Skip to content

Commit 46c5805

Browse files
committed
[Op] Add C++ gradient op for Select.
Cherry-pick from TensorFlow a0dab8a
1 parent f4b57ae commit 46c5805

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

tensorflow/cc/gradients/math_grad.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,25 @@ Status CastGrad(const Scope& scope, const Operation& op,
11731173
}
11741174
REGISTER_GRADIENT_OP("Cast", CastGrad);
11751175

1176+
Status SelectGrad(const Scope& scope, const Operation& op,
1177+
const std::vector<Output>& grad_inputs,
1178+
std::vector<Output>* grad_outputs) {
1179+
if (op.num_inputs() != 3) {
1180+
return errors::InvalidArgument("Select requires 3 arguments");
1181+
}
1182+
if (grad_inputs.size() != 1) {
1183+
return errors::InvalidArgument("Select grad requires 1 grad input");
1184+
}
1185+
1186+
auto c = op.input(0);
1187+
auto zeros = ZerosLike(scope, grad_inputs[0]);
1188+
grad_outputs->push_back(NoGradient()); // Condition
1189+
grad_outputs->push_back(Where3(scope, c, grad_inputs[0], zeros));
1190+
grad_outputs->push_back(Where3(scope, c, zeros, grad_inputs[0]));
1191+
return scope.status();
1192+
}
1193+
REGISTER_GRADIENT_OP("Select", SelectGrad);
1194+
11761195
} // anonymous namespace
11771196
} // namespace ops
11781197
} // namespace tensorflow

tensorflow/cc/gradients/math_grad_test.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ using ops::SegmentSum;
5252
using ops::SquaredDifference;
5353
using ops::Sub;
5454
using ops::Sum;
55+
using ops::Where3;
5556

5657
// TODO(andydavis) Test gradient function against numeric gradients output.
5758
// TODO(andydavis) As more gradients are added move common test functions
@@ -985,5 +986,14 @@ TEST_F(NaryGradTest, CastGrad) {
985986
EXPECT_LT(max_error, 1e-3);
986987
}
987988

989+
TEST_F(NaryGradTest, Select) {
990+
TensorShape shape({1, 3});
991+
auto cond = Const<bool>(scope_, {{false, true, true}});
992+
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
993+
auto y = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
994+
auto z = Where3(scope_, cond, x, y);
995+
RunTest({x, y}, {shape, shape}, {z}, {shape});
996+
}
997+
988998
} // namespace
989999
} // namespace tensorflow

0 commit comments

Comments
 (0)