Commit 9acef18

[Op] Rollback c++ gradient op for FusedBatchNormV3Grad.
1 parent 1e664a8

File tree

tensorflow/cc/gradients/nn_grad.cc
tensorflow/cc/gradients/nn_grad_test.cc

2 files changed: +0 -150 lines

tensorflow/cc/gradients/nn_grad.cc

Lines changed: 0 additions & 115 deletions
@@ -394,121 +394,6 @@ Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper);
 
-// Templated constructor for FusedBatchNormGrad[..]::Attrs.
-template <typename T>
-T FusedBatchNormGradAttrs(float epsilon, std::string data_format,
-                          bool is_training) {
-  T result;
-  result.epsilon_ = epsilon;
-  result.data_format_ = data_format;
-  result.is_training_ = is_training;
-  return result;
-}
-
-using BatchNormGradFn =
-    std::function<Status(const Scope&, Output x, Output grad_y, Output scale,
-                         const std::vector<Output>& reserve_spaces,
-                         float epsilon, std::string data_format,
-                         bool is_training, std::vector<Output>* grad_outputs)>;
-
-Status BaseFusedBatchNormGrad(const Scope& scope, const Operation& op,
-                              const std::vector<Output>& grad_inputs,
-                              BatchNormGradFn grad_fn,
-                              std::vector<Output>* grad_outputs) {
-  if (op.num_outputs() < 5) {
-    return errors::InvalidArgument(
-        "FusedBatchNorm requires at least 5 outputs");
-  }
-  if (grad_inputs.empty()) {
-    return errors::InvalidArgument("FusedBatchNorm grad requires 1 grad input");
-  }
-  if (op.num_inputs() < 3) {
-    return errors::InvalidArgument("FusedBatchNorm has too few inputs");
-  }
-
-  Output x = op.input(0);
-  Output grad_y = grad_inputs[0];
-  Output scale = op.input(1);
-  float epsilon;
-  std::string data_format;
-  bool is_training;
-  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "epsilon", &epsilon));
-  TF_RETURN_IF_ERROR(
-      GetNodeAttr(op.node()->attrs(), "data_format", &data_format));
-  TF_RETURN_IF_ERROR(
-      GetNodeAttr(op.node()->attrs(), "is_training", &is_training));
-
-  std::vector<Output> reserve_spaces;
-  reserve_spaces.push_back(op.output(3));
-  reserve_spaces.push_back(op.output(4));
-  if (op.num_outputs() > 5) {
-    reserve_spaces.push_back(op.output(5));
-  }
-
-  if (is_training) {
-    return grad_fn(scope, x, grad_y, scale, reserve_spaces, epsilon,
-                   data_format, is_training, grad_outputs);
-  } else {
-    if (op.num_inputs() < 5) {
-      return errors::InvalidArgument(
-          "FusedBatchNorm requires 5 inputs in eval mode");
-    }
-
-    reserve_spaces[0] = op.input(3);  // pop_mean
-    reserve_spaces[1] = op.input(4);  // pop_var
-    if (data_format == "NCHW") {
-      x = Transpose(scope, x, {0, 2, 3, 1});
-      grad_y = Transpose(scope, grad_y, {0, 2, 3, 1});
-    } else if (data_format == "NCDHW") {
-      x = Transpose(scope, x, {0, 2, 3, 4, 1});
-      grad_y = Transpose(scope, grad_y, {0, 2, 3, 4, 1});
-    }
-
-    std::string target_data_format;
-    if (data_format == "NCHW" || data_format == "NHWC") {
-      target_data_format = "NHWC";
-    } else {
-      target_data_format = "NDHWC";
-    }
-
-    TF_RETURN_IF_ERROR(grad_fn(scope, x, grad_y, scale, reserve_spaces, epsilon,
-                               target_data_format, is_training, grad_outputs));
-    if (data_format == "NCHW") {
-      (*grad_outputs)[0] = Transpose(scope, (*grad_outputs)[0], {0, 3, 1, 2});
-    } else if (data_format == "NCDHW") {
-      (*grad_outputs)[0] =
-          Transpose(scope, (*grad_outputs)[0], {0, 4, 1, 2, 3});
-    }
-    return scope.status();
-  }
-}
-
-Status FusedBatchNormV3Grad(const Scope& scope, const Operation& op,
-                            const std::vector<Output>& grad_inputs,
-                            std::vector<Output>* grad_outputs) {
-  return BaseFusedBatchNormGrad(
-      scope, op, grad_inputs,
-      [](const Scope& scope, Output x, Output grad_y, Output scale,
-         const std::vector<Output>& reserve_spaces, float epsilon,
-         std::string data_format, bool is_training,
-         std::vector<Output>* grad_outputs) {
-        FusedBatchNormGradV3 grad(
-            scope, grad_y, x, scale, reserve_spaces[0], reserve_spaces[1],
-            reserve_spaces[2],
-            FusedBatchNormGradAttrs<FusedBatchNormGradV3::Attrs>(
-                epsilon, data_format, is_training));
-        grad_outputs->push_back(grad.x_backprop);
-        grad_outputs->push_back(grad.scale_backprop);
-        grad_outputs->push_back(grad.offset_backprop);
-        grad_outputs->push_back(NoGradient());
-        grad_outputs->push_back(NoGradient());
-        return scope.status();
-      },
-      grad_outputs);
-}
-
-REGISTER_GRADIENT_OP("FusedBatchNormV3", FusedBatchNormV3Grad);
-
 Status Conv2DBackpropInputGrad(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs) {
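
Context for the removal above: in the tensorflow/cc C++ API, gradient functions registered via REGISTER_GRADIENT_OP are looked up by op name when a caller requests symbolic gradients. The sketch below is not part of this commit; it is a minimal, illustrative use of the stock AddSymbolicGradients API from tensorflow/cc/framework/gradients.h, with made-up shapes, showing the call path that this rollback disables for FusedBatchNormV3.

// Hypothetical illustration (not in this commit): requesting symbolic
// gradients for a graph containing FusedBatchNormV3.
#include <vector>

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

using namespace tensorflow;       // example-only convenience
using namespace tensorflow::ops;

Status TryFusedBatchNormV3Grad() {
  Scope scope = Scope::NewRootScope();
  // NHWC input with 5 channels; scale/offset/mean/variance are per-channel.
  auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape({2, 3, 4, 5}));
  auto scale = Placeholder(scope, DT_FLOAT, Placeholder::Shape({5}));
  auto offset = Placeholder(scope, DT_FLOAT, Placeholder::Shape({5}));
  auto mean = ZerosLike(scope, scale);
  auto var = OnesLike(scope, scale);
  auto y = FusedBatchNormV3(scope, x, scale, offset, mean, var);

  // AddSymbolicGradients walks the graph and resolves each op's gradient
  // function from the registry populated by REGISTER_GRADIENT_OP. With the
  // FusedBatchNormV3 entry removed by this commit, the call returns an
  // error instead of filling `grads`.
  std::vector<Output> grads;
  return AddSymbolicGradients(scope, {y.y}, {x, scale, offset}, &grads);
}

Note that Python-side gradients are registered separately (in nn_grad.py), so this rollback only affects the C++ gradient registry.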

tensorflow/cc/gradients/nn_grad_test.cc

Lines changed: 0 additions & 35 deletions
@@ -35,7 +35,6 @@ using ops::Conv2DBackpropInput;
 using ops::Elu;
 using ops::FractionalAvgPool;
 using ops::FractionalMaxPool;
-using ops::FusedBatchNormV3;
 using ops::L2Loss;
 using ops::LogSoftmax;
 using ops::LRN;
@@ -344,40 +343,6 @@ TEST_F(NNGradTest, FractionalMaxPoolGradHelper) {
   RunTest(x, x_init_value, y.output, y_shape);
 }
 
-class FusedBatchNormGradTest : public NNGradTest,
-                               public ::testing::WithParamInterface<
-                                   std::tuple<bool, bool, TensorShape>> {};
-
-TEST_P(FusedBatchNormGradTest, FusedBatchNormV3Grad) {
-  FusedBatchNormV3::Attrs attrs;
-  attrs.is_training_ = std::get<0>(GetParam());
-  bool channel_first = std::get<1>(GetParam());
-  TensorShape shape = std::get<2>(GetParam());
-  int channel_dim = (channel_first) ? 1 : shape.dims() - 1;
-  TensorShape scale_shape({shape.dim_size(channel_dim)});
-  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
-  auto scale = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(scale_shape));
-  auto offset = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(scale_shape));
-  auto mean = ops::ZerosLike(scope_, scale);
-  auto var = ops::OnesLike(scope_, scale);
-
-  if (!channel_first) {
-    attrs.data_format_ = (shape.dims() == 5) ? "NDHWC" : "NHWC";
-  } else {
-    attrs.data_format_ = (shape.dims() == 5) ? "NCDHW" : "NCHW";
-  }
-
-  auto y = FusedBatchNormV3(scope_, x, scale, offset, mean, var, attrs);
-  RunTest({x, scale, offset}, {shape, scale_shape, scale_shape}, {y.y},
-          {shape});
-}
-
-INSTANTIATE_TEST_SUITE_P(
-    FusedBatchNormGrad, FusedBatchNormGradTest,
-    ::testing::Combine(::testing::Bool(), ::testing::Bool(),
-                       ::testing::Values(TensorShape({2, 3, 4, 5}),
-                                         TensorShape({2, 3, 2, 2, 2}))));
-
 TEST_F(NNGradTest, Conv2DBackpropInputGrad) {
   TensorShape shape({1, 2, 2, 1});
   TensorShape filter_shape({1, 1, 1, 1});