@@ -120,6 +120,20 @@ Status SplitGrad(const Scope& scope, const Operation& op,
120120}
121121REGISTER_GRADIENT_OP (" Split" , SplitGrad);
122122
123+ Status SplitVGrad (const Scope& scope, const Operation& op,
124+ const std::vector<Output>& grad_inputs,
125+ std::vector<Output>* grad_outputs) {
126+ if (op.num_inputs () < 3 ) {
127+ return errors::InvalidArgument (" SplitV requires 3 arguments" );
128+ }
129+ grad_outputs->push_back (Concat (scope, grad_inputs, op.input (2 )));
130+ for (int i = 0 ; i < op.num_inputs () - 1 ; ++i) {
131+ grad_outputs->push_back (NoGradient ());
132+ }
133+ return scope.status ();
134+ }
135+ REGISTER_GRADIENT_OP (" SplitV" , SplitVGrad);
136+
123137Status FillGrad (const Scope& scope, const Operation& op,
124138 const std::vector<Output>& grad_inputs,
125139 std::vector<Output>* grad_outputs) {
@@ -491,6 +505,262 @@ Status SliceGrad(const Scope& scope, const Operation& op,
491505}
492506REGISTER_GRADIENT_OP (" Slice" , SliceGrad);
493507
508+ Status ConcatGradHelper (const Scope& scope, const Operation& op,
509+ const std::vector<Output>& grad_inputs,
510+ std::vector<Output>* grad_outputs,
511+ int start_value_index, int end_value_index,
512+ int dim_index) {
513+ if (end_value_index >= op.num_inputs ()) {
514+ return errors::Internal (" Invalid input index" );
515+ }
516+ std::vector<Output> inputs;
517+ for (int i = start_value_index; i < end_value_index; ++i) {
518+ inputs.push_back (op.input (i));
519+ }
520+
521+ auto shapes = ShapeN (scope, inputs);
522+ const auto unique_name = scope.GetUniqueNameForOp (" ConcatOffset" );
523+ auto builder =
524+ ::tensorflow::NodeBuilder (unique_name, " ConcatOffset" )
525+ .Input(::tensorflow::ops::AsNodeOut(scope, op.input(dim_index)))
526+ .Input(::tensorflow::ops::AsNodeOutList(scope, shapes.output));
527+ scope.UpdateBuilder (&builder);
528+ ::tensorflow::Node* concat_offset_node;
529+ scope.UpdateStatus (builder.Finalize (scope.graph (), &concat_offset_node));
530+ scope.UpdateStatus (scope.DoShapeInference (concat_offset_node));
531+ if (concat_offset_node->num_outputs () != inputs.size ()) {
532+ return errors::Internal (" ConcatOffset has invalid output count" );
533+ }
534+ if (grad_inputs.size () != 1 ) {
535+ return errors::InvalidArgument (" Concat grad should have 1 input" );
536+ }
537+
538+ // For each dx[i], we take a slice of dy. The offset and size of the
539+ // slice is given by offset[i] and shape[i].
540+ const Output& dy = grad_inputs[0 ];
541+ for (int i = 0 ; i < inputs.size (); ++i) {
542+ grad_outputs->push_back (
543+ Slice (scope, dy, Output (concat_offset_node, i), shapes.output [i]));
544+ }
545+
546+ // Insert a NoGradient for the axis.
547+ grad_outputs->insert (grad_outputs->begin () + dim_index, NoGradient ());
548+ return scope.status ();
549+ }
550+
551+ Status ConcatV2Grad (const Scope& scope, const Operation& op,
552+ const std::vector<Output>& grad_inputs,
553+ std::vector<Output>* grad_outputs) {
554+ return ConcatGradHelper (scope, op, grad_inputs, grad_outputs,
555+ /* start_value_index=*/ 0 ,
556+ /* end_value_index=*/ op.num_inputs () - 1 ,
557+ /* dim+index=*/ op.num_inputs () - 1 );
558+ }
559+
560+ REGISTER_GRADIENT_OP (" ConcatV2" , ConcatV2Grad);
561+
562+ Status BroadcastToGrad (const Scope& scope, const Operation& op,
563+ const std::vector<Output>& grad_inputs,
564+ std::vector<Output>* grad_outputs) {
565+ if (grad_inputs.size () != 1 ) {
566+ return errors::InvalidArgument (" BroadcastTo grad should have 1 grad input" );
567+ }
568+ if (op.num_inputs () != 2 ) {
569+ return errors::InvalidArgument (" BroadcastTo requires 2 inputs" );
570+ }
571+
572+ auto x_shape = Shape (scope, op.input (0 ));
573+ auto args = internal::BroadcastGradientArgs (scope, x_shape, op.input (1 ));
574+ auto sum_gx = Sum (scope, grad_inputs[0 ], args.r0 );
575+ grad_outputs->push_back (Reshape (scope, sum_gx, x_shape));
576+ grad_outputs->push_back (NoGradient ());
577+ return scope.status ();
578+ }
579+
580+ REGISTER_GRADIENT_OP (" BroadcastTo" , BroadcastToGrad);
581+
582+ Status TileGrad (const Scope& scope, const Operation& op,
583+ const std::vector<Output>& grad_inputs,
584+ std::vector<Output>* grad_outputs) {
585+ if (op.num_inputs () != 2 ) {
586+ return errors::InvalidArgument (" Tile requires 2 inputs" );
587+ }
588+ if (grad_inputs.size () != 1 ) {
589+ return errors::InvalidArgument (" Tile grad requires 1 grad input" );
590+ }
591+
592+ Shape::Attrs shape_attrs;
593+ shape_attrs.out_type_ = op.input_type (1 );
594+ auto input_shape = Shape (scope, op.input (0 ), shape_attrs);
595+ // We interleave multiples and input_shape to get split_shape,
596+ // reshape grad to split_shape, and reduce along all even
597+ // dimensions (the tiled dimensions) to get the result
598+ // with shape input_shape. For example
599+ // input_shape = [20, 30, 40]
600+ // multiples = [2, 3, 4]
601+ // split_shape = [2, 20, 3, 30, 4, 40]
602+ // axes = [0, 2, 4]
603+ auto stack = Stack (scope, {op.input (1 ), input_shape.output });
604+ auto perm = Range (scope, Sub (scope, Rank (scope, stack), 1 ), -1 , -1 );
605+ auto split_shape = Reshape (scope, Transpose (scope, stack, perm), {-1 });
606+ auto axes = Range (scope, Const (scope, 0 ), Size (scope, split_shape.output ), 2 );
607+ auto input_grad = ReduceSum (
608+ scope, Reshape (scope, grad_inputs[0 ], split_shape.output ), axes.output );
609+ grad_outputs->push_back (input_grad.output );
610+ grad_outputs->push_back (NoGradient ());
611+ return scope.status ();
612+ }
613+ REGISTER_GRADIENT_OP (" Tile" , TileGrad);
614+
615+ // Create a constant of the provided d_type;
616+ Output ConstHelper (const Scope& scope, int value, DataType d_type) {
617+ return Cast (scope, Const (scope, value), d_type);
618+ }
619+
620+ // Adds the batch offsets to the given indices and returns the results.
621+ Output GetBatchIndices (const Scope& scope, const Output& params_shape,
622+ const Output& indices, int batch_dims) {
623+ Output batch_indices = indices;
624+ auto indices_ndims = Rank (scope, indices);
625+ auto casted_params_shape = Cast (scope, params_shape, indices.type ());
626+ Output accum_dim_value = ConstHelper (scope, 1 , indices.type ());
627+ for (int dim = batch_dims; dim > 0 ; dim--) {
628+ Output dim_value = Slice (scope, casted_params_shape, {dim - 1 }, {1 });
629+ accum_dim_value = Multiply (scope, accum_dim_value,
630+ Slice (scope, casted_params_shape, {dim}, {1 }));
631+ auto start = ConstHelper (scope, 0 , indices.type ());
632+ auto step = ConstHelper (scope, 1 , indices.type ());
633+ Output dim_indices = Range (scope, start, Squeeze (scope, dim_value), step);
634+ dim_indices = Multiply (scope, dim_indices, accum_dim_value);
635+ auto one = Cast (scope, Const (scope, {1 }), indices.type ());
636+ auto dim_shape = Concat (
637+ scope,
638+ {Output (Tile (scope, one, Const (scope, {dim - 1 }))), dim_value,
639+ Output (Tile (scope, one,
640+ ExpandDims (scope, Sub (scope, indices_ndims, dim), 0 )))},
641+ /* axis=*/ 0 );
642+ batch_indices =
643+ Add (scope, batch_indices, Reshape (scope, dim_indices, dim_shape));
644+ }
645+
646+ return batch_indices;
647+ }
648+
649+ Output BatchGatherGrad (const Scope& scope, Output params_shape, Output values,
650+ Output indices, int batch_dims, Output gather_dim_size) {
651+ // Axis is the first non-batch dimension.
652+ auto indices_size = ExpandDims (scope, Size (scope, indices), 0 );
653+ Output outer_shape, flat_values_shape;
654+ if (batch_dims != 0 ) {
655+ auto values_shape = Shape (scope, values);
656+ // Add the batch offsets to indices and flatten the batch dimensions.
657+ outer_shape = Slice (scope, values_shape, {0 }, {batch_dims});
658+ auto inner_shape =
659+ Slice (scope, Slice (scope, values_shape, {batch_dims}, {-1 }), {1 }, {-1 });
660+ auto batch_size = Prod (scope, outer_shape, /* axis=*/ 0 );
661+ flat_values_shape = Concat (scope, {{-1 }, inner_shape}, /* axis=*/ 0 );
662+ gather_dim_size = Multiply (scope, gather_dim_size, batch_size);
663+ indices = GetBatchIndices (scope, params_shape, indices, batch_dims);
664+ values = Reshape (scope, values, flat_values_shape);
665+ }
666+
667+ indices = Reshape (scope, indices, indices_size);
668+ Output params_grad =
669+ UnsortedSegmentSum (scope, values, indices, gather_dim_size);
670+
671+ if (batch_dims != 0 ) {
672+ // Put back the batch dimensions.
673+ params_grad = Reshape (scope, params_grad, params_shape);
674+ }
675+ return params_grad;
676+ }
677+
678+ Status GatherV2Grad (const Scope& scope, const Operation& op,
679+ const std::vector<Output>& grad_inputs,
680+ std::vector<Output>* grad_outputs) {
681+ if (op.num_inputs () != 3 ) {
682+ return errors::InvalidArgument (" Gather requires 3 inputs" );
683+ }
684+ if (grad_inputs.size () != 1 ) {
685+ return errors::InvalidArgument (" Gather grad requires 1 grad input" );
686+ }
687+
688+ // params can be large, so colocate the shape calculation with it.
689+ // params can be very large for sparse model, array_ops.shape raises
690+ // exception on the Windows platform when any dimension is larger than
691+ // int32. params_shape is not used in optimizer apply_sparse gradients,
692+ // so it's fine to convert it back to int32 regardless of truncation.
693+ auto params = op.input (0 );
694+ auto colocate_scope = scope.ColocateWith (params);
695+ Shape::Attrs shape_attrs;
696+ shape_attrs.out_type_ = DT_INT64;
697+ auto params_shape64 = Shape (colocate_scope, params, shape_attrs);
698+ Output params_shape = Cast (colocate_scope, params_shape64, DT_INT32);
699+
700+ auto indices = op.input (1 );
701+ auto indices_size = ExpandDims (scope, Size (scope, indices), 0 );
702+ auto axis = op.input (2 );
703+ auto axis_expand = ExpandDims (scope, axis, 0 );
704+
705+ int batch_dims;
706+ TF_RETURN_IF_ERROR (
707+ GetNodeAttr (op.node ()->attrs (), " batch_dims" , &batch_dims));
708+ if (batch_dims < 0 ) {
709+ // TODO(bdodson): Figure out if we can find the param rank here, like the
710+ // python implementation does.
711+ return errors::InvalidArgument (
712+ " C++ GatherV2 gradient does not support negative batch_dims." );
713+ }
714+
715+ // Handle axis by transposing the axis dimension to be the first non-batch
716+ // dimension, compute the gradient and transpose the result back.
717+ auto outer_shape = Slice (scope, params_shape, {0 }, axis_expand);
718+ auto inner_shape =
719+ Slice (scope, Slice (scope, params_shape, axis_expand, {-1 }), {1 }, {-1 });
720+ auto values_shape = Concat (scope, {outer_shape, {-1 }, inner_shape}, 0 );
721+ auto values_dims = Size (scope, values_shape);
722+ auto axis_dims = Size (scope, outer_shape);
723+
724+ Output outer_batches_indices = Range (scope, 0 , batch_dims, /* delta=*/ 1 );
725+ Output batch_axis_indices = Range (scope, batch_dims, axis_dims, /* delta=*/ 1 );
726+ Output inner_axes_indices =
727+ Range (scope, Add (scope, axis_dims, 1 ), values_dims, /* delta=*/ 1 );
728+ Output axis_dims_expand = ExpandDims (scope, axis_dims, 0 );
729+
730+ auto values = Reshape (scope, grad_inputs[0 ], values_shape);
731+
732+ // Move values[axis] up to values[batch_dims]
733+ Output transpose_dims = Concat (scope,
734+ {outer_batches_indices, axis_dims_expand,
735+ batch_axis_indices, inner_axes_indices},
736+ 0 );
737+ auto values_transpose = Transpose (scope, values, transpose_dims);
738+ Output gather_dim_size =
739+ Squeeze (scope, Slice (scope, params_shape, axis_expand, {1 }));
740+ params_shape = Gather (scope, params_shape, transpose_dims);
741+
742+ auto params_grad = BatchGatherGrad (scope, params_shape, values_transpose,
743+ indices, batch_dims, gather_dim_size);
744+
745+ // Inverts the above transpose by moving dimension batch_dims back to its
746+ // original position.
747+ Output invert_transpose_dims = Concat (scope,
748+ {outer_batches_indices,
749+ Add (scope, batch_axis_indices, 1 ),
750+ {batch_dims},
751+ inner_axes_indices},
752+ 0 );
753+
754+ params_grad = Transpose (scope, params_grad, invert_transpose_dims);
755+
756+ grad_outputs->push_back (params_grad);
757+ grad_outputs->push_back (NoGradient ());
758+ grad_outputs->push_back (NoGradient ());
759+ return scope.status ();
760+ }
761+
762+ REGISTER_GRADIENT_OP (" GatherV2" , GatherV2Grad);
763+
494764} // anonymous namespace
495765} // namespace ops
496766} // namespace tensorflow
0 commit comments