Skip to content

Commit 1e664a8

Browse files
committed
[Op] Add C++ gradients for UnsortedSegmentMin/Max/Sum.
Cherry-pick from TensorFlow 2e99f65
1 parent 1c111ff commit 1e664a8

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed

tensorflow/cc/gradients/math_grad.cc

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ limitations under the License.
1616
#define _USE_MATH_DEFINES
1717
#include <cmath>
1818

19+
#include "tensorflow/cc/ops/array_ops.h"
1920
#include "tensorflow/cc/ops/array_ops_internal.h"
21+
#include "tensorflow/cc/ops/math_ops.h"
2022
#include "tensorflow/cc/ops/math_ops_internal.h"
2123
#include "tensorflow/cc/ops/standard_ops.h"
2224

@@ -1248,6 +1250,101 @@ Status SelectV2Grad(const Scope& scope, const Operation& op,
12481250

12491251
REGISTER_GRADIENT_OP("SelectV2", SelectV2Grad);
12501252

1253+
// Helper function for unsorted segment ops.
1254+
// Returns 'ids' with negative elements replaced by 0.
1255+
Output GetZeroClippedIndices(const Scope& scope, const Output& ids) {
1256+
return Maximum(scope, ids, ZerosLike(scope, ids));
1257+
}
1258+
1259+
// Helper function for unsorted segment ops.
1260+
// Returns a mask of where 'ids' are positive, reshaped so that it will be
1261+
// broadcastable to the result shape of gathering params by ids.
1262+
Output GetIsPositive(const Scope& scope, const Output& params,
1263+
const Output& ids) {
1264+
Output is_positive = GreaterEqual(scope, ids, ZerosLike(scope, ids));
1265+
// tf.where(condition, x, y) requires condition to have the same shape as x
1266+
// and y.
1267+
Output is_positive_shape = Shape(scope, is_positive);
1268+
Output ones =
1269+
Tile(scope, Const(scope, {1}), Subtract(scope, Rank(scope, params), {1}));
1270+
auto broadcastable_shape = Concat(scope, {is_positive_shape, ones},
1271+
/*axis=*/0);
1272+
is_positive = Reshape(scope, is_positive, broadcastable_shape);
1273+
is_positive = LogicalAnd(scope, is_positive, OnesLike(scope, is_positive));
1274+
return is_positive;
1275+
}
1276+
1277+
// Helper function for unsorted segment ops.
1278+
// Gathers params for positive segment ids and gathers 0 for inputs with
1279+
// negative segment id.
1280+
Output GatherDropNegatives(const Scope& scope, const Output& params,
1281+
Output& zero_clipped_indices, Output& is_positive) {
1282+
auto gathered = Gather(scope, params, zero_clipped_indices);
1283+
// Replace gathered params of negative indices with 0.
1284+
auto zero_slice = ZerosLike(scope, gathered);
1285+
return SelectV2(scope, is_positive, gathered, zero_slice);
1286+
}
1287+
1288+
Status UnsortedSegmentMinOrMaxGrad(const Scope& scope, const Operation& op,
1289+
const std::vector<Output>& grad_inputs,
1290+
std::vector<Output>* grad_outputs) {
1291+
if (op.num_inputs() != 3) {
1292+
return errors::InvalidArgument("UnsortedSegmentMax requires 3 arguments");
1293+
}
1294+
1295+
if (grad_inputs.size() != 1) {
1296+
return errors::InvalidArgument(
1297+
"UnsortedSegmentMax grad requires 1 grad input");
1298+
}
1299+
1300+
auto grad = grad_inputs[0];
1301+
// Get the number of selected (minimum or maximum) elements in each segment.
1302+
auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1));
1303+
auto is_positive = GetIsPositive(scope, op.output(0), op.input(1));
1304+
Output gathered_outputs = GatherDropNegatives(
1305+
scope, op.output(0), zero_clipped_indices, is_positive);
1306+
Output is_selected = Equal(scope, op.input(0), gathered_outputs);
1307+
is_selected = LogicalAnd(scope, is_selected, is_positive);
1308+
auto num_selected = UnsortedSegmentSum(
1309+
scope, Cast(scope, is_selected, grad.type()), op.input(1), op.input(2));
1310+
// Compute the gradient for each segment.The gradient for the ith segment is
1311+
// divided evenly among the selected elements in that segment.
1312+
auto weighted_grads = Div(scope, grad, num_selected);
1313+
auto gathered_grads = GatherDropNegatives(scope, weighted_grads,
1314+
zero_clipped_indices, is_positive);
1315+
auto zeros = ZerosLike(scope, gathered_grads);
1316+
grad_outputs->push_back(SelectV2(scope, is_selected, gathered_grads, zeros));
1317+
grad_outputs->push_back(NoGradient());
1318+
grad_outputs->push_back(NoGradient());
1319+
return scope.status();
1320+
}
1321+
1322+
REGISTER_GRADIENT_OP("UnsortedSegmentMax", UnsortedSegmentMinOrMaxGrad);
1323+
REGISTER_GRADIENT_OP("UnsortedSegmentMin", UnsortedSegmentMinOrMaxGrad);
1324+
1325+
Status UnsortedSegmentSumGrad(const Scope& scope, const Operation& op,
1326+
const std::vector<Output>& grad_inputs,
1327+
std::vector<Output>* grad_outputs) {
1328+
if (op.num_inputs() != 3) {
1329+
return errors::InvalidArgument("UnsortedSegmentSum requires 3 arguments");
1330+
}
1331+
1332+
if (grad_inputs.size() != 1) {
1333+
return errors::InvalidArgument(
1334+
"UnsortedSegmentSum grad requires 1 grad input");
1335+
}
1336+
1337+
auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1));
1338+
auto is_positive = GetIsPositive(scope, grad_inputs[0], op.input(1));
1339+
grad_outputs->push_back(GatherDropNegatives(
1340+
scope, grad_inputs[0], zero_clipped_indices, is_positive));
1341+
grad_outputs->push_back(NoGradient());
1342+
grad_outputs->push_back(NoGradient());
1343+
return scope.status();
1344+
}
1345+
1346+
REGISTER_GRADIENT_OP("UnsortedSegmentSum", UnsortedSegmentSumGrad);
1347+
12511348
} // anonymous namespace
12521349
} // namespace ops
12531350
} // namespace tensorflow

tensorflow/cc/gradients/math_grad_test.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ using ops::SelectV2;
5454
using ops::SquaredDifference;
5555
using ops::Sub;
5656
using ops::Sum;
57+
using ops::UnsortedSegmentMax;
58+
using ops::UnsortedSegmentMin;
59+
using ops::UnsortedSegmentSum;
5760
using ops::Where3;
5861

5962
// TODO(andydavis) Test gradient function against numeric gradients output.
@@ -1033,5 +1036,50 @@ TEST_F(NaryGradTest, Atan2Grad) {
10331036
RunTest({x1, x2}, {shape, shape}, {y}, {shape});
10341037
}
10351038

1039+
// Checks the UnsortedSegmentMax gradient for a {3, 2, 5} float input whose
// three rows are assigned to segments {0, 0, 1} (int literals as ids).
TEST_F(NaryGradTest, UnsortedSegmentMaxGrad) {
  const TensorShape input_shape({3, 2, 5});
  const TensorShape output_shape({2, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape));
  auto segment_ids = Const(scope_, {0, 0, 1});
  auto y = UnsortedSegmentMax(scope_, x, segment_ids, /*num_segments=*/2);
  RunTest({x}, {input_shape}, {y}, {output_shape});
}
1047+
1048+
// Same as the basic UnsortedSegmentMax gradient check, but the segment ids
// are given as long long literals instead of int literals.
TEST_F(NaryGradTest, UnsortedSegmentMaxGrad_Int64Ids) {
  const TensorShape input_shape({3, 2, 5});
  const TensorShape output_shape({2, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape));
  auto segment_ids = Const(scope_, {0ll, 0ll, 1ll});
  auto y = UnsortedSegmentMax(scope_, x, segment_ids, /*num_segments=*/2);
  RunTest({x}, {input_shape}, {y}, {output_shape});
}
1056+
1057+
// Checks the UnsortedSegmentMax gradient when one segment id is negative:
// ids are {0, 0, -1} with a single output segment.
TEST_F(NaryGradTest, UnsortedSegmentMaxGrad_NegativeIds) {
  const TensorShape input_shape({3, 2, 5});
  const TensorShape output_shape({1, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape));
  auto segment_ids = Const(scope_, {0, 0, -1});
  auto y = UnsortedSegmentMax(scope_, x, segment_ids, /*num_segments=*/1);
  RunTest({x}, {input_shape}, {y}, {output_shape});
}
1065+
1066+
// Checks the UnsortedSegmentMin gradient for a {3, 2, 5} float input whose
// three rows are assigned to segments {0, 0, 1}.
TEST_F(NaryGradTest, UnsortedSegmentMinGrad) {
  const TensorShape input_shape({3, 2, 5});
  const TensorShape output_shape({2, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape));
  auto segment_ids = Const(scope_, {0, 0, 1});
  auto y = UnsortedSegmentMin(scope_, x, segment_ids, /*num_segments=*/2);
  RunTest({x}, {input_shape}, {y}, {output_shape});
}
1074+
1075+
// Checks the UnsortedSegmentSum gradient for a {3, 2, 5} float input whose
// three rows are assigned to segments {0, 0, 1}.
TEST_F(NaryGradTest, UnsortedSegmentSumGrad) {
  const TensorShape input_shape({3, 2, 5});
  const TensorShape output_shape({2, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape));
  auto segment_ids = Const(scope_, {0, 0, 1});
  auto y = UnsortedSegmentSum(scope_, x, segment_ids, /*num_segments=*/2);
  RunTest({x}, {input_shape}, {y}, {output_shape});
}
1083+
10361084
} // namespace
10371085
} // namespace tensorflow

0 commit comments

Comments
 (0)