
Commit f4b57ae

[Op] Add gradient implementations for the following ops:
- SplitV
- ConcatV2
- BroadcastTo
- Tile
- GatherV2
- Cumsum
- Cast
- FusedBatchNormV3
- Conv2DBackpropInput

Most implementations are translated from the Python versions; ConcatV2 copies the implementation from tensorflow/core/ops/array_grad.cc.

Cherry-picked from TensorFlow. Commit ID: bf3d89b
1 parent 9590fdd commit f4b57ae
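For context on how gradient registrations like the ones in this commit are consumed, below is a minimal, illustrative sketch of driving one of the new gradients (Tile) through the C++ symbolic-gradient API in tensorflow/cc/framework/gradients.h. The choice of op, the input values, and the shapes are assumptions for the example, not code from this commit.

// Illustrative sketch only: op choice, values, and shapes are assumptions.
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"

int main() {
  using namespace tensorflow;       // for brevity in this sketch
  using namespace tensorflow::ops;

  Scope scope = Scope::NewRootScope();
  auto x = Const(scope, {{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}});  // 2x3
  auto y = Tile(scope, x, Const(scope, {2, 2}));              // 4x6

  // AddSymbolicGradients looks up the gradient function registered via
  // REGISTER_GRADIENT_OP("Tile", ...) and appends the backward ops to the
  // graph; the upstream gradient defaults to ones_like(y).
  std::vector<Output> grads;
  TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));

  ClientSession session(scope);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run({grads[0]}, &outputs));
  // Each element of x appears in 4 tiles of y, so outputs[0] is a 2x3
  // tensor filled with 4.
  return 0;
}

No extra wiring is needed beyond the REGISTER_GRADIENT_OP entries: the symbolic-gradient builder resolves them by op name at graph-construction time.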

6 files changed: +722 −0 lines changed

tensorflow/cc/gradients/array_grad.cc

Lines changed: 270 additions & 0 deletions
@@ -120,6 +120,20 @@ Status SplitGrad(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("Split", SplitGrad);
 
+Status SplitVGrad(const Scope& scope, const Operation& op,
+                  const std::vector<Output>& grad_inputs,
+                  std::vector<Output>* grad_outputs) {
+  if (op.num_inputs() < 3) {
+    return errors::InvalidArgument("SplitV requires 3 arguments");
+  }
+  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(2)));
+  for (int i = 0; i < op.num_inputs() - 1; ++i) {
+    grad_outputs->push_back(NoGradient());
+  }
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("SplitV", SplitVGrad);
+
 Status FillGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
@@ -491,6 +505,262 @@ Status SliceGrad(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("Slice", SliceGrad);
 
+Status ConcatGradHelper(const Scope& scope, const Operation& op,
+                        const std::vector<Output>& grad_inputs,
+                        std::vector<Output>* grad_outputs,
+                        int start_value_index, int end_value_index,
+                        int dim_index) {
+  if (end_value_index >= op.num_inputs()) {
+    return errors::Internal("Invalid input index");
+  }
+  std::vector<Output> inputs;
+  for (int i = start_value_index; i < end_value_index; ++i) {
+    inputs.push_back(op.input(i));
+  }
+
+  auto shapes = ShapeN(scope, inputs);
+  const auto unique_name = scope.GetUniqueNameForOp("ConcatOffset");
+  auto builder =
+      ::tensorflow::NodeBuilder(unique_name, "ConcatOffset")
+          .Input(::tensorflow::ops::AsNodeOut(scope, op.input(dim_index)))
+          .Input(::tensorflow::ops::AsNodeOutList(scope, shapes.output));
+  scope.UpdateBuilder(&builder);
+  ::tensorflow::Node* concat_offset_node;
+  scope.UpdateStatus(builder.Finalize(scope.graph(), &concat_offset_node));
+  scope.UpdateStatus(scope.DoShapeInference(concat_offset_node));
+  if (concat_offset_node->num_outputs() != inputs.size()) {
+    return errors::Internal("ConcatOffset has invalid output count");
+  }
+  if (grad_inputs.size() != 1) {
+    return errors::InvalidArgument("Concat grad should have 1 input");
+  }
+
+  // For each dx[i], we take a slice of dy. The offset and size of the
+  // slice is given by offset[i] and shape[i].
+  const Output& dy = grad_inputs[0];
+  for (int i = 0; i < inputs.size(); ++i) {
+    grad_outputs->push_back(
+        Slice(scope, dy, Output(concat_offset_node, i), shapes.output[i]));
+  }
+
+  // Insert a NoGradient for the axis.
+  grad_outputs->insert(grad_outputs->begin() + dim_index, NoGradient());
+  return scope.status();
+}
+
+Status ConcatV2Grad(const Scope& scope, const Operation& op,
+                    const std::vector<Output>& grad_inputs,
+                    std::vector<Output>* grad_outputs) {
+  return ConcatGradHelper(scope, op, grad_inputs, grad_outputs,
+                          /*start_value_index=*/0,
+                          /*end_value_index=*/op.num_inputs() - 1,
+                          /*dim_index=*/op.num_inputs() - 1);
+}
+
+REGISTER_GRADIENT_OP("ConcatV2", ConcatV2Grad);
+
+Status BroadcastToGrad(const Scope& scope, const Operation& op,
+                       const std::vector<Output>& grad_inputs,
+                       std::vector<Output>* grad_outputs) {
+  if (grad_inputs.size() != 1) {
+    return errors::InvalidArgument("BroadcastTo grad should have 1 grad input");
+  }
+  if (op.num_inputs() != 2) {
+    return errors::InvalidArgument("BroadcastTo requires 2 inputs");
+  }
+
+  auto x_shape = Shape(scope, op.input(0));
+  auto args = internal::BroadcastGradientArgs(scope, x_shape, op.input(1));
+  auto sum_gx = Sum(scope, grad_inputs[0], args.r0);
+  grad_outputs->push_back(Reshape(scope, sum_gx, x_shape));
+  grad_outputs->push_back(NoGradient());
+  return scope.status();
+}
+
+REGISTER_GRADIENT_OP("BroadcastTo", BroadcastToGrad);
+
+Status TileGrad(const Scope& scope, const Operation& op,
+                const std::vector<Output>& grad_inputs,
+                std::vector<Output>* grad_outputs) {
+  if (op.num_inputs() != 2) {
+    return errors::InvalidArgument("Tile requires 2 inputs");
+  }
+  if (grad_inputs.size() != 1) {
+    return errors::InvalidArgument("Tile grad requires 1 grad input");
+  }
+
+  Shape::Attrs shape_attrs;
+  shape_attrs.out_type_ = op.input_type(1);
+  auto input_shape = Shape(scope, op.input(0), shape_attrs);
+  // We interleave multiples and input_shape to get split_shape,
+  // reshape grad to split_shape, and reduce along all even
+  // dimensions (the tiled dimensions) to get the result
+  // with shape input_shape. For example
+  //   input_shape = [20, 30, 40]
+  //   multiples = [2, 3, 4]
+  //   split_shape = [2, 20, 3, 30, 4, 40]
+  //   axes = [0, 2, 4]
+  auto stack = Stack(scope, {op.input(1), input_shape.output});
+  auto perm = Range(scope, Sub(scope, Rank(scope, stack), 1), -1, -1);
+  auto split_shape = Reshape(scope, Transpose(scope, stack, perm), {-1});
+  auto axes = Range(scope, Const(scope, 0), Size(scope, split_shape.output), 2);
+  auto input_grad = ReduceSum(
+      scope, Reshape(scope, grad_inputs[0], split_shape.output), axes.output);
+  grad_outputs->push_back(input_grad.output);
+  grad_outputs->push_back(NoGradient());
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Tile", TileGrad);
+
+// Create a constant of the provided d_type.
+Output ConstHelper(const Scope& scope, int value, DataType d_type) {
+  return Cast(scope, Const(scope, value), d_type);
+}
+
+// Adds the batch offsets to the given indices and returns the results.
+Output GetBatchIndices(const Scope& scope, const Output& params_shape,
+                       const Output& indices, int batch_dims) {
+  Output batch_indices = indices;
+  auto indices_ndims = Rank(scope, indices);
+  auto casted_params_shape = Cast(scope, params_shape, indices.type());
+  Output accum_dim_value = ConstHelper(scope, 1, indices.type());
+  for (int dim = batch_dims; dim > 0; dim--) {
+    Output dim_value = Slice(scope, casted_params_shape, {dim - 1}, {1});
+    accum_dim_value = Multiply(scope, accum_dim_value,
+                               Slice(scope, casted_params_shape, {dim}, {1}));
+    auto start = ConstHelper(scope, 0, indices.type());
+    auto step = ConstHelper(scope, 1, indices.type());
+    Output dim_indices = Range(scope, start, Squeeze(scope, dim_value), step);
+    dim_indices = Multiply(scope, dim_indices, accum_dim_value);
+    auto one = Cast(scope, Const(scope, {1}), indices.type());
+    auto dim_shape = Concat(
+        scope,
+        {Output(Tile(scope, one, Const(scope, {dim - 1}))), dim_value,
+         Output(Tile(scope, one,
+                     ExpandDims(scope, Sub(scope, indices_ndims, dim), 0)))},
+        /*axis=*/0);
+    batch_indices =
+        Add(scope, batch_indices, Reshape(scope, dim_indices, dim_shape));
+  }
+
+  return batch_indices;
+}
+
+Output BatchGatherGrad(const Scope& scope, Output params_shape, Output values,
+                       Output indices, int batch_dims, Output gather_dim_size) {
+  // Axis is the first non-batch dimension.
+  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
+  Output outer_shape, flat_values_shape;
+  if (batch_dims != 0) {
+    auto values_shape = Shape(scope, values);
+    // Add the batch offsets to indices and flatten the batch dimensions.
+    outer_shape = Slice(scope, values_shape, {0}, {batch_dims});
+    auto inner_shape =
+        Slice(scope, Slice(scope, values_shape, {batch_dims}, {-1}), {1}, {-1});
+    auto batch_size = Prod(scope, outer_shape, /*axis=*/0);
+    flat_values_shape = Concat(scope, {{-1}, inner_shape}, /*axis=*/0);
+    gather_dim_size = Multiply(scope, gather_dim_size, batch_size);
+    indices = GetBatchIndices(scope, params_shape, indices, batch_dims);
+    values = Reshape(scope, values, flat_values_shape);
+  }
+
+  indices = Reshape(scope, indices, indices_size);
+  Output params_grad =
+      UnsortedSegmentSum(scope, values, indices, gather_dim_size);
+
+  if (batch_dims != 0) {
+    // Put back the batch dimensions.
+    params_grad = Reshape(scope, params_grad, params_shape);
+  }
+  return params_grad;
+}
+
+Status GatherV2Grad(const Scope& scope, const Operation& op,
+                    const std::vector<Output>& grad_inputs,
+                    std::vector<Output>* grad_outputs) {
+  if (op.num_inputs() != 3) {
+    return errors::InvalidArgument("Gather requires 3 inputs");
+  }
+  if (grad_inputs.size() != 1) {
+    return errors::InvalidArgument("Gather grad requires 1 grad input");
+  }
+
+  // params can be large, so colocate the shape calculation with it.
+  // params can be very large for sparse model, array_ops.shape raises
+  // exception on the Windows platform when any dimension is larger than
+  // int32. params_shape is not used in optimizer apply_sparse gradients,
+  // so it's fine to convert it back to int32 regardless of truncation.
+  auto params = op.input(0);
+  auto colocate_scope = scope.ColocateWith(params);
+  Shape::Attrs shape_attrs;
+  shape_attrs.out_type_ = DT_INT64;
+  auto params_shape64 = Shape(colocate_scope, params, shape_attrs);
+  Output params_shape = Cast(colocate_scope, params_shape64, DT_INT32);
+
+  auto indices = op.input(1);
+  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
+  auto axis = op.input(2);
+  auto axis_expand = ExpandDims(scope, axis, 0);
+
+  int batch_dims;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(op.node()->attrs(), "batch_dims", &batch_dims));
+  if (batch_dims < 0) {
+    // TODO(bdodson): Figure out if we can find the param rank here, like the
+    // python implementation does.
+    return errors::InvalidArgument(
+        "C++ GatherV2 gradient does not support negative batch_dims.");
+  }
+
+  // Handle axis by transposing the axis dimension to be the first non-batch
+  // dimension, compute the gradient and transpose the result back.
+  auto outer_shape = Slice(scope, params_shape, {0}, axis_expand);
+  auto inner_shape =
+      Slice(scope, Slice(scope, params_shape, axis_expand, {-1}), {1}, {-1});
+  auto values_shape = Concat(scope, {outer_shape, {-1}, inner_shape}, 0);
+  auto values_dims = Size(scope, values_shape);
+  auto axis_dims = Size(scope, outer_shape);
+
+  Output outer_batches_indices = Range(scope, 0, batch_dims, /*delta=*/1);
+  Output batch_axis_indices = Range(scope, batch_dims, axis_dims, /*delta=*/1);
+  Output inner_axes_indices =
+      Range(scope, Add(scope, axis_dims, 1), values_dims, /*delta=*/1);
+  Output axis_dims_expand = ExpandDims(scope, axis_dims, 0);
+
+  auto values = Reshape(scope, grad_inputs[0], values_shape);
+
+  // Move values[axis] up to values[batch_dims]
+  Output transpose_dims = Concat(scope,
+                                 {outer_batches_indices, axis_dims_expand,
+                                  batch_axis_indices, inner_axes_indices},
+                                 0);
+  auto values_transpose = Transpose(scope, values, transpose_dims);
+  Output gather_dim_size =
+      Squeeze(scope, Slice(scope, params_shape, axis_expand, {1}));
+  params_shape = Gather(scope, params_shape, transpose_dims);
+
+  auto params_grad = BatchGatherGrad(scope, params_shape, values_transpose,
+                                     indices, batch_dims, gather_dim_size);
+
+  // Inverts the above transpose by moving dimension batch_dims back to its
+  // original position.
+  Output invert_transpose_dims = Concat(scope,
+                                        {outer_batches_indices,
+                                         Add(scope, batch_axis_indices, 1),
+                                         {batch_dims},
+                                         inner_axes_indices},
+                                        0);
+
+  params_grad = Transpose(scope, params_grad, invert_transpose_dims);
+
+  grad_outputs->push_back(params_grad);
+  grad_outputs->push_back(NoGradient());
+  grad_outputs->push_back(NoGradient());
+  return scope.status();
+}
+
+REGISTER_GRADIENT_OP("GatherV2", GatherV2Grad);
+
 }  // anonymous namespace
 }  // namespace ops
 }  // namespace tensorflow
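Gradient functions like these are normally checked against a numeric estimate. Below is a minimal, illustrative sketch of that pattern using TensorFlow's C++ gradient checker; the op under test (BroadcastTo), the shapes, and the tolerance are assumptions for illustration, not test code from this commit.

// Illustrative sketch only: op choice, shapes, and tolerance are assumptions.
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/ops/standard_ops.h"

void CheckBroadcastToGradient() {
  using namespace tensorflow;       // for brevity in this sketch
  using namespace tensorflow::ops;

  Scope scope = Scope::NewRootScope();
  TensorShape x_shape({2, 1, 3});
  TensorShape y_shape({2, 4, 3});
  auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = BroadcastTo(scope, x, Const(scope, {2, 4, 3}));

  // ComputeGradientError builds the registered symbolic gradient
  // (BroadcastToGrad) and compares it against a numerically estimated
  // Jacobian, returning the largest absolute difference.
  float max_error;
  TF_CHECK_OK((ComputeGradientError<float, float, float>(
      scope, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
  CHECK_LT(max_error, 1e-4);  // A gtest-based test would use EXPECT_LT.
}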
