@@ -394,121 +394,6 @@ Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper);
 
-// Templated constructor for FusedBatchNormGrad[..]::Attrs.
-template <typename T>
-T FusedBatchNormGradAttrs(float epsilon, std::string data_format,
-                          bool is_training) {
-  T result;
-  result.epsilon_ = epsilon;
-  result.data_format_ = data_format;
-  result.is_training_ = is_training;
-  return result;
-}
-
-using BatchNormGradFn =
-    std::function<Status(const Scope&, Output x, Output grad_y, Output scale,
-                         const std::vector<Output>& reserve_spaces,
-                         float epsilon, std::string data_format,
-                         bool is_training, std::vector<Output>* grad_outputs)>;
-
-Status BaseFusedBatchNormGrad(const Scope& scope, const Operation& op,
-                              const std::vector<Output>& grad_inputs,
-                              BatchNormGradFn grad_fn,
-                              std::vector<Output>* grad_outputs) {
-  if (op.num_outputs() < 5) {
-    return errors::InvalidArgument(
-        "FusedBatchNorm requires at least 5 outputs");
-  }
-  if (grad_inputs.empty()) {
-    return errors::InvalidArgument("FusedBatchNorm grad requires 1 grad input");
-  }
-  if (op.num_inputs() < 3) {
-    return errors::InvalidArgument("FusedBatchNorm has too few inputs");
-  }
-
-  Output x = op.input(0);
-  Output grad_y = grad_inputs[0];
-  Output scale = op.input(1);
-  float epsilon;
-  std::string data_format;
-  bool is_training;
-  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "epsilon", &epsilon));
-  TF_RETURN_IF_ERROR(
-      GetNodeAttr(op.node()->attrs(), "data_format", &data_format));
-  TF_RETURN_IF_ERROR(
-      GetNodeAttr(op.node()->attrs(), "is_training", &is_training));
-
-  std::vector<Output> reserve_spaces;
-  reserve_spaces.push_back(op.output(3));
-  reserve_spaces.push_back(op.output(4));
-  if (op.num_outputs() > 5) {
-    reserve_spaces.push_back(op.output(5));
-  }
-
-  if (is_training) {
-    return grad_fn(scope, x, grad_y, scale, reserve_spaces, epsilon,
-                   data_format, is_training, grad_outputs);
-  } else {
-    if (op.num_inputs() < 5) {
-      return errors::InvalidArgument(
-          "FusedBatchNorm requires 5 inputs in eval mode");
-    }
-
-    reserve_spaces[0] = op.input(3);  // pop_mean
-    reserve_spaces[1] = op.input(4);  // pop_var
-    if (data_format == "NCHW") {
-      x = Transpose(scope, x, {0, 2, 3, 1});
-      grad_y = Transpose(scope, grad_y, {0, 2, 3, 1});
-    } else if (data_format == "NCDHW") {
-      x = Transpose(scope, x, {0, 2, 3, 4, 1});
-      grad_y = Transpose(scope, grad_y, {0, 2, 3, 4, 1});
-    }
-
-    std::string target_data_format;
-    if (data_format == "NCHW" || data_format == "NHWC") {
-      target_data_format = "NHWC";
-    } else {
-      target_data_format = "NDHWC";
-    }
-
-    TF_RETURN_IF_ERROR(grad_fn(scope, x, grad_y, scale, reserve_spaces, epsilon,
-                               target_data_format, is_training, grad_outputs));
-    if (data_format == "NCHW") {
-      (*grad_outputs)[0] = Transpose(scope, (*grad_outputs)[0], {0, 3, 1, 2});
-    } else if (data_format == "NCDHW") {
-      (*grad_outputs)[0] =
-          Transpose(scope, (*grad_outputs)[0], {0, 4, 1, 2, 3});
-    }
-    return scope.status();
-  }
-}
-
-Status FusedBatchNormV3Grad(const Scope& scope, const Operation& op,
-                            const std::vector<Output>& grad_inputs,
-                            std::vector<Output>* grad_outputs) {
-  return BaseFusedBatchNormGrad(
-      scope, op, grad_inputs,
-      [](const Scope& scope, Output x, Output grad_y, Output scale,
-         const std::vector<Output>& reserve_spaces, float epsilon,
-         std::string data_format, bool is_training,
-         std::vector<Output>* grad_outputs) {
-        FusedBatchNormGradV3 grad(
-            scope, grad_y, x, scale, reserve_spaces[0], reserve_spaces[1],
-            reserve_spaces[2],
-            FusedBatchNormGradAttrs<FusedBatchNormGradV3::Attrs>(
-                epsilon, data_format, is_training));
-        grad_outputs->push_back(grad.x_backprop);
-        grad_outputs->push_back(grad.scale_backprop);
-        grad_outputs->push_back(grad.offset_backprop);
-        grad_outputs->push_back(NoGradient());
-        grad_outputs->push_back(NoGradient());
-        return scope.status();
-      },
-      grad_outputs);
-}
-
-REGISTER_GRADIENT_OP("FusedBatchNormV3", FusedBatchNormV3Grad);
-
 Status Conv2DBackpropInputGrad(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs) {
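
For anyone tracing the eval-mode branch of the removed BaseFusedBatchNormGrad: the inference-mode FusedBatchNormGrad kernels expect channel-last layouts, so the helper transposed NCHW/NCDHW tensors to NHWC/NDHWC, ran the gradient there, and transposed the x-gradient back. A minimal sketch of that round trip, assuming the tensorflow/cc ops headers this file already uses; GradInNhwc and nhwc_grad_fn are illustrative names, not part of the change:

// Sketch only: wrap an NHWC-only gradient computation so it accepts NCHW.
// Perm {0, 2, 3, 1} moves the channel axis last; {0, 3, 1, 2} restores it.
Output GradInNhwc(const Scope& scope, Output x_nchw,
                  const std::function<Output(Output)>& nhwc_grad_fn) {
  Output x_nhwc = Transpose(scope, x_nchw, {0, 2, 3, 1});  // channel -> last
  Output g_nhwc = nhwc_grad_fn(x_nhwc);                    // grad in NHWC
  return Transpose(scope, g_nhwc, {0, 3, 1, 2});           // channel -> first
}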
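
The deleted code also followed the registration pattern used throughout this file: a free function with the (Scope, Operation, grad_inputs, grad_outputs) gradient signature that pushes one gradient per op input, using NoGradient() for non-differentiable inputs, wired up via REGISTER_GRADIENT_OP. A minimal sketch with a hypothetical op name:

// Sketch only: "SomeOp" is hypothetical; the signature and macro match the
// pattern visible above for FractionalMaxPool and FusedBatchNormV3.
Status SomeOpGradHelper(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  // Identity-style gradient: pass the incoming gradient straight through.
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("SomeOp", SomeOpGradHelper);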