Skip to content

Commit 1c111ff

Browse files
committed
[Op] Fix an issue with SelectV2 gradient broadcasting.
The Where3 op used previously doesn't have the same broadcasting semantics as SelectV2, which caused an error when computing gradients. Cherry-pick from TensorFlow 42fad1e
1 parent ea6da1d commit 1c111ff

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

tensorflow/cc/gradients/math_grad.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ Status SelectV2Grad(const Scope& scope, const Operation& op,
12211221
auto y = op.input(2);
12221222

12231223
auto zeros = ZerosLike(scope, grad_inputs[0]);
1224-
auto gx = Where3(scope, c, grad_inputs[0], zeros);
1224+
auto gx = SelectV2(scope, c, grad_inputs[0], zeros);
12251225
auto x_shape = Shape(scope, x);
12261226
auto output_shape = Shape(scope, op.output(0));
12271227

@@ -1231,7 +1231,7 @@ Status SelectV2Grad(const Scope& scope, const Operation& op,
12311231
ReduceSum(scope, gx, /*axis=*/reduce_x.r0, ReduceSum::KeepDims(true));
12321232
auto gx_sum_reshape = Reshape(scope, gx_sum, x_shape);
12331233

1234-
auto gy = Where3(scope, c, zeros, grad_inputs[0]);
1234+
auto gy = SelectV2(scope, c, zeros, grad_inputs[0]);
12351235
auto y_shape = Shape(scope, y);
12361236

12371237
// Reduce away broadcasted leading dims.

tensorflow/cc/gradients/math_grad_test.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,15 @@ TEST_F(NaryGradTest, SelectV2_Broadcast) {
10161016
RunTest({x, y}, {x_shape, y_shape}, {z}, {x_shape});
10171017
}
10181018

1019+
TEST_F(NaryGradTest, SelectV2_Broadcast2) {
1020+
TensorShape x_shape({2, 3});
1021+
auto cond = Const<bool>(scope_, {{false}, {true}});
1022+
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
1023+
auto y = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
1024+
auto z = SelectV2(scope_, cond, x, y);
1025+
RunTest({x, y}, {x_shape, x_shape}, {z}, {x_shape});
1026+
}
1027+
10191028
TEST_F(NaryGradTest, Atan2Grad) {
10201029
TensorShape shape({3, 2, 5});
10211030
auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));

0 commit comments

Comments
 (0)