@@ -16,7 +16,9 @@ limitations under the License.
 #define _USE_MATH_DEFINES
 #include <cmath>
 
+#include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/array_ops_internal.h"
+#include "tensorflow/cc/ops/math_ops.h"
 #include "tensorflow/cc/ops/math_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 
@@ -1248,6 +1250,101 @@ Status SelectV2Grad(const Scope& scope, const Operation& op,
 
 REGISTER_GRADIENT_OP("SelectV2", SelectV2Grad);
 
+// Helper function for unsorted segment ops.
+// Returns 'ids' with negative elements replaced by 0.
+Output GetZeroClippedIndices(const Scope& scope, const Output& ids) {
+  return Maximum(scope, ids, ZerosLike(scope, ids));
+}
+
+// Helper function for unsorted segment ops.
+// Returns a mask of where 'ids' are non-negative, reshaped so that it will be
+// broadcastable to the result shape of gathering params by ids.
+Output GetIsPositive(const Scope& scope, const Output& params,
+                     const Output& ids) {
+  Output is_positive = GreaterEqual(scope, ids, ZerosLike(scope, ids));
+  // tf.where(condition, x, y) requires condition to have the same shape as x
+  // and y.
+  Output is_positive_shape = Shape(scope, is_positive);
+  Output ones = Tile(scope, Const(scope, {1}),
+                     Subtract(scope, Rank(scope, params), {1}));
+  auto broadcastable_shape =
+      Concat(scope, {is_positive_shape, ones}, /*axis=*/0);
+  is_positive = Reshape(scope, is_positive, broadcastable_shape);
+  is_positive = LogicalAnd(scope, is_positive, OnesLike(scope, is_positive));
+  return is_positive;
+}
+
+// Helper function for unsorted segment ops.
+// Gathers params for non-negative segment ids and gathers 0 for inputs with
+// a negative segment id.
+Output GatherDropNegatives(const Scope& scope, const Output& params,
+                           const Output& zero_clipped_indices,
+                           const Output& is_positive) {
+  auto gathered = Gather(scope, params, zero_clipped_indices);
+  // Replace gathered params of negative indices with 0.
+  auto zero_slice = ZerosLike(scope, gathered);
+  return SelectV2(scope, is_positive, gathered, zero_slice);
+}
+
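A small worked example (illustrative only, not part of the patch) of what GatherDropNegatives is meant to compute: with hypothetical params = [10, 20, 30] and ids = [1, -1, 0], the zero-clipped ids are [1, 0, 0], the gather yields [20, 10, 10], and entries whose original id was negative are masked to 0, giving [20, 0, 10].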
+Status UnsortedSegmentMinOrMaxGrad(const Scope& scope, const Operation& op,
+                                   const std::vector<Output>& grad_inputs,
+                                   std::vector<Output>* grad_outputs) {
+  if (op.num_inputs() != 3) {
+    return errors::InvalidArgument("UnsortedSegmentMax requires 3 arguments");
+  }
+
+  if (grad_inputs.size() != 1) {
+    return errors::InvalidArgument(
+        "UnsortedSegmentMax grad requires 1 grad input");
+  }
+
+  auto grad = grad_inputs[0];
+  // Get the number of selected (minimum or maximum) elements in each segment.
+  auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1));
+  auto is_positive = GetIsPositive(scope, op.output(0), op.input(1));
+  Output gathered_outputs = GatherDropNegatives(
+      scope, op.output(0), zero_clipped_indices, is_positive);
+  Output is_selected = Equal(scope, op.input(0), gathered_outputs);
+  is_selected = LogicalAnd(scope, is_selected, is_positive);
+  auto num_selected = UnsortedSegmentSum(
+      scope, Cast(scope, is_selected, grad.type()), op.input(1), op.input(2));
+  // Compute the gradient for each segment. The gradient for the ith segment
+  // is divided evenly among the selected elements in that segment.
+  auto weighted_grads = Div(scope, grad, num_selected);
+  auto gathered_grads = GatherDropNegatives(scope, weighted_grads,
+                                            zero_clipped_indices, is_positive);
+  auto zeros = ZerosLike(scope, gathered_grads);
+  grad_outputs->push_back(SelectV2(scope, is_selected, gathered_grads, zeros));
+  grad_outputs->push_back(NoGradient());
+  grad_outputs->push_back(NoGradient());
+  return scope.status();
+}
+
+REGISTER_GRADIENT_OP("UnsortedSegmentMax", UnsortedSegmentMinOrMaxGrad);
+REGISTER_GRADIENT_OP("UnsortedSegmentMin", UnsortedSegmentMinOrMaxGrad);
+
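As background for reviewers (not part of the patch), here is a minimal sketch of how the newly registered gradient could be exercised from the C++ API, assuming the usual ClientSession / AddSymbolicGradients workflow; the data, ids, and expected values below are hypothetical:

// Illustrative sketch only; not part of this change.
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

int main() {
  using namespace tensorflow;
  using namespace tensorflow::ops;

  Scope root = Scope::NewRootScope();
  // Hypothetical inputs: two segments, with a tie for the maximum in segment 0.
  auto data = Const(root, {1.0f, 3.0f, 3.0f, 2.0f});
  auto ids = Const(root, {0, 0, 0, 1});
  auto y = UnsortedSegmentMax(root, data, ids, /*num_segments=*/Const(root, 2));

  // Request d(y)/d(data); this dispatches to the gradient registered above.
  std::vector<Output> grads;
  TF_CHECK_OK(AddSymbolicGradients(root, {y}, {data}, &grads));

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run(grads, &outputs));
  // The incoming gradient defaults to ones, and each segment's gradient is
  // split evenly over the elements equal to that segment's maximum, so the
  // expected result is roughly [0, 0.5, 0.5, 1].
  return 0;
}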
+Status UnsortedSegmentSumGrad(const Scope& scope, const Operation& op,
+                              const std::vector<Output>& grad_inputs,
+                              std::vector<Output>* grad_outputs) {
+  if (op.num_inputs() != 3) {
+    return errors::InvalidArgument("UnsortedSegmentSum requires 3 arguments");
+  }
+
+  if (grad_inputs.size() != 1) {
+    return errors::InvalidArgument(
+        "UnsortedSegmentSum grad requires 1 grad input");
+  }
+
+  auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1));
+  auto is_positive = GetIsPositive(scope, grad_inputs[0], op.input(1));
+  grad_outputs->push_back(GatherDropNegatives(
+      scope, grad_inputs[0], zero_clipped_indices, is_positive));
+  grad_outputs->push_back(NoGradient());
+  grad_outputs->push_back(NoGradient());
+  return scope.status();
+}
+
+REGISTER_GRADIENT_OP("UnsortedSegmentSum", UnsortedSegmentSumGrad);
+
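For comparison, a worked example of the sum case (illustrative only, hypothetical values): UnsortedSegmentSum only scatters and adds, so its gradient with respect to data is the incoming gradient gathered by segment id, with 0 for negative ids. With data = [1, 2, 3], ids = [0, 1, -1], num_segments = 2 and incoming gradient [g0, g1], the gradient w.r.t. data is [g0, g1, 0].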
 }  // anonymous namespace
 }  // namespace ops
 }  // namespace tensorflow