Skip to content

Commit 2144852

Browse files
authored
[CHERRY-PICK] Reduce grad fix cherrypick (#32742)
* base changes for fix
* minor change
* fix for bwd kernel
* removed unnecessary import
* implemented reviewers' suggestions
* CI fix
1 parent 9a589de commit 2144852

File tree

5 files changed

+79
-73
lines changed

5 files changed

+79
-73
lines changed

paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ class ReduceMeanGradMKLDNNKernel : public ReduceGradMKLDNNKernel<T> {
4545
number_of_elements = input_x->numel();
4646
}
4747

48-
this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f,
48+
this->RunKernel(ctx, dnnl::algorithm::binary_add,
49+
dnnl::algorithm::reduction_mean, 0.0f,
4950
1.0L / number_of_elements);
5051
}
5152
};

paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,27 @@ using paddle::framework::LoDTensor;
2121
using paddle::framework::Tensor;
2222
using platform::to_void_cast;
2323

24+
inline std::vector<int64_t> CalculateReducedDims(const Tensor* input,
25+
const Tensor* output,
26+
std::vector<int>& reduce_dims,
27+
bool reduce_all,
28+
bool keep_dim) {
29+
if (keep_dim) return framework::vectorize(output->dims());
30+
31+
if (reduce_all)
32+
return std::vector<int64_t>(framework::vectorize(input->dims()).size(), 1);
33+
34+
std::vector<int64_t> output_dims(framework::vectorize(input->dims()));
35+
for (size_t i = 0; i < reduce_dims.size(); ++i) {
36+
reduce_dims[i] = (reduce_dims[i] >= 0)
37+
? reduce_dims[i]
38+
: input->dims().size() + reduce_dims[i];
39+
output_dims[reduce_dims[i]] = 1;
40+
}
41+
42+
return output_dims;
43+
}
44+
2445
template <typename T>
2546
class ReduceMKLDNNKernel : public framework::OpKernel<T> {
2647
public:
@@ -37,9 +58,8 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
3758
bool reduce_all = ctx.Attr<bool>("reduce_all");
3859
bool keep_dim = ctx.Attr<bool>("keep_dim");
3960

40-
std::vector<int64_t> output_dims =
41-
CalculateOutputDims(input, output, reduce_dims, reduce_all, keep_dim);
42-
61+
auto output_dims =
62+
CalculateReducedDims(input, output, reduce_dims, reduce_all, keep_dim);
4363
auto input_dims = framework::vectorize(input->dims());
4464

4565
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
@@ -96,53 +116,63 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
96116
paddle::framework::vectorize<int64_t>(output->dims()))));
97117
}
98118
}
99-
100-
private:
101-
std::vector<int64_t> CalculateOutputDims(const Tensor* input,
102-
const Tensor* output,
103-
std::vector<int>& reduce_dims,
104-
bool reduce_all,
105-
bool keep_dim) const {
106-
if (keep_dim) return framework::vectorize(output->dims());
107-
108-
if (reduce_all)
109-
return std::vector<int64_t>(framework::vectorize(input->dims()).size(),
110-
1);
111-
112-
std::vector<int64_t> output_dims(framework::vectorize(input->dims()));
113-
for (size_t i = 0; i < reduce_dims.size(); ++i) {
114-
reduce_dims[i] = (reduce_dims[i] >= 0)
115-
? reduce_dims[i]
116-
: input->dims().size() + reduce_dims[i];
117-
output_dims[reduce_dims[i]] = 1;
118-
}
119-
120-
return output_dims;
121-
}
122119
};
123120

124121
template <typename T>
125122
class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
126123
public:
127124
void RunKernel(const framework::ExecutionContext& ctx,
128-
dnnl::algorithm binary_type, float scale_x,
129-
float scale_y) const {
125+
dnnl::algorithm binary_type, dnnl::algorithm reduction_type,
126+
float scale_x, float scale_y) const {
130127
const auto& dev_ctx =
131128
ctx.template device_context<platform::MKLDNNDeviceContext>();
132129
const auto& onednn_engine = dev_ctx.GetEngine();
133130

131+
bool keep_dim = ctx.Attr<bool>("keep_dim");
132+
bool reduce_all = ctx.Attr<bool>("reduce_all");
134133
auto dims = ctx.Attr<std::vector<int>>("dim");
135134
auto* input_dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
136135
auto* output_dx = ctx.Output<Tensor>(framework::GradVarName("X"));
137136

137+
mkldnn::memory::format_tag x_format_tag;
138+
auto input_dims =
139+
CalculateReducedDims(output_dx, input_dy, dims, reduce_all, keep_dim);
140+
141+
if (input_dims != framework::vectorize(output_dx->dims())) {
142+
const std::string key_pd =
143+
platform::CreateKey(
144+
dev_ctx, framework::vectorize(output_dx->dims()),
145+
ctx.InputName("X"),
146+
(std::to_string(static_cast<int>(reduction_type)))) +
147+
"@fwd_pd";
148+
std::shared_ptr<dnnl::reduction::primitive_desc> fwd_pd =
149+
std::static_pointer_cast<dnnl::reduction::primitive_desc>(
150+
dev_ctx.GetBlob(key_pd));
151+
152+
PADDLE_ENFORCE_NOT_NULL(
153+
fwd_pd, platform::errors::Unavailable(
154+
"Forward primitive descriptor is not available in %s op, "
155+
"cannot deduce memory format tag",
156+
ctx.Type()));
157+
158+
x_format_tag = platform::GetMKLDNNFormat(fwd_pd->src_desc());
159+
160+
PADDLE_ENFORCE_NE(x_format_tag, mkldnn::memory::format_tag::undef,
161+
platform::errors::InvalidArgument(
162+
"Cannot deduce format tag for %s op", ctx.Type()));
163+
} else { // fwd descriptor not available because reorder was used instead
164+
// of reduction
165+
x_format_tag = getPlainFormatTag(output_dx);
166+
}
167+
138168
output_dx->mutable_data<T>(ctx.GetPlace());
139-
output_dx->set_format(getPlainFormatTag(output_dx));
169+
output_dx->set_format(x_format_tag);
140170
output_dx->set_layout(input_dy->layout());
141171

142172
platform::BroadcastDataMKLDNNHandler<T> handler(
143173
binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx,
144174
input_dy, scale_x, scale_y,
145-
ctx.InputName(framework::GradVarName("Out")));
175+
ctx.InputName(framework::GradVarName("Out")), input_dims);
146176

147177
const auto src_dx_memory = handler.AcquireSrcMemory(output_dx);
148178
const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy);

paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ template <typename T>
2929
class ReduceSumGradMKLDNNKernel : public ReduceGradMKLDNNKernel<T> {
3030
public:
3131
void Compute(const framework::ExecutionContext& ctx) const override {
32-
this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, 1.0f);
32+
this->RunKernel(ctx, dnnl::algorithm::binary_add,
33+
dnnl::algorithm::reduction_sum, 0.0f, 1.0f);
3334
}
3435
};
3536

paddle/fluid/operators/reduce_ops/reduce_op.h

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -559,27 +559,18 @@ class ReduceGradOp : public framework::OperatorWithKernel {
559559
protected:
560560
framework::OpKernelType GetExpectedKernelType(
561561
const framework::ExecutionContext& ctx) const override {
562-
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
563-
ctx, framework::GradVarName("Out"));
562+
int in_dtype = ctx.Attr<int>("in_dtype");
563+
auto input_data_type =
564+
(in_dtype >= 0) ? static_cast<framework::proto::VarType::Type>(in_dtype)
565+
: OperatorWithKernel::IndicateVarDataType(
566+
ctx, framework::GradVarName("Out"));
564567

565568
#ifdef PADDLE_WITH_MKLDNN
566569
auto CanMKLDNNReduceGradBeUsed = [&]() {
567570
auto dx_dims = ctx.Input<Tensor>("X")->dims();
568571

569572
if (dx_dims.size() > 5) return false; // max 5D tensor is supported
570573

571-
if (ctx.Attr<bool>("reduce_all") ||
572-
((int)ctx.Attr<std::vector<int>>("dim").size() == dx_dims.size()))
573-
return true;
574-
575-
auto dy_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
576-
577-
// Subtensor must be on rightmost part of the bigger tensor
578-
for (int i = 0; i < dy_dims.size(); ++i) {
579-
if (dx_dims[dx_dims.size() - dy_dims.size() + i] != dy_dims[i]) {
580-
return false;
581-
}
582-
}
583574
return true;
584575
};
585576
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
@@ -590,12 +581,6 @@ class ReduceGradOp : public framework::OperatorWithKernel {
590581
}
591582
#endif
592583

593-
int in_dtype = ctx.Attr<int>("in_dtype");
594-
if (in_dtype >= 0) {
595-
return framework::OpKernelType(
596-
static_cast<framework::proto::VarType::Type>(in_dtype),
597-
ctx.GetPlace());
598-
}
599584
return framework::OpKernelType(input_data_type, ctx.GetPlace());
600585
}
601586
};

paddle/fluid/platform/mkldnn_reuse.h

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,8 @@ class BroadcastDataMKLDNNHandler
639639
const mkldnn::engine engine,
640640
platform::Place cpu_place, const Tensor* x,
641641
const Tensor* y, float scale_x, float scale_y,
642-
const std::string& uniq_name)
642+
const std::string& uniq_name,
643+
std::vector<int64_t>& input_dims)
643644
: platform::MKLDNNHandlerT<T, dnnl::binary>(
644645
dev_ctx, engine, cpu_place,
645646
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
@@ -659,24 +660,12 @@ class BroadcastDataMKLDNNHandler
659660
y->format(), MKLDNNMemoryFormat::undef,
660661
platform::errors::InvalidArgument("Wrong format set for Y tensor."));
661662

662-
auto src1_tz = framework::vectorize(y->dims());
663663
const auto src0_tz = framework::vectorize(x->dims());
664664

665-
// GetExpectedKernelType checks if smaller vector is a subvector with all
666-
// the dims in correct order on the rightmost part of the bigger vector,
667-
// i.e. a correct vector for broadcasting:
668-
// x = 5, 7, 3, 2, 4, 8
669-
// y = 4, 8
670-
src1_tz.reserve(src0_tz.size());
671-
672-
for (size_t i = src1_tz.size(); i < src0_tz.size(); ++i) {
673-
src1_tz.insert(src1_tz.begin(), 1L);
674-
}
675-
676665
const auto src0_md = dnnl::memory::desc(
677666
src0_tz, platform::MKLDNNGetDataType<T>(), x->format());
678667
const auto src1_md = dnnl::memory::desc(
679-
src1_tz, platform::MKLDNNGetDataType<T>(), x->format());
668+
input_dims, platform::MKLDNNGetDataType<T>(), x->format());
680669

681670
dnnl::primitive_attr attributes;
682671
attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
@@ -711,7 +700,7 @@ class ReductionMKLDNNHandler
711700
const mkldnn::engine engine, platform::Place cpu_place,
712701
const Tensor* x, const Tensor* y,
713702
const std::string& uniq_name,
714-
std::vector<int64_t> output_dims)
703+
std::vector<int64_t> y_tz)
715704
: platform::MKLDNNHandlerT<T, dnnl::reduction>(
716705
dev_ctx, engine, cpu_place,
717706
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
@@ -725,14 +714,14 @@ class ReductionMKLDNNHandler
725714
x->format(), MKLDNNMemoryFormat::undef,
726715
platform::errors::InvalidArgument("Wrong format set for X tensor."));
727716

728-
const auto src_tz = framework::vectorize(x->dims());
717+
const auto x_tz = framework::vectorize(x->dims());
729718

730-
const auto src_md = dnnl::memory::desc(
731-
src_tz, platform::MKLDNNGetDataType<T>(), x->format());
732-
const auto dst_md = memory::desc(
733-
output_dims, platform::MKLDNNGetDataType<T>(), x->format());
719+
const auto x_md = dnnl::memory::desc(
720+
x_tz, platform::MKLDNNGetDataType<T>(), x->format());
721+
const auto y_md =
722+
memory::desc(y_tz, platform::MKLDNNGetDataType<T>(), x->format());
734723

735-
this->AcquireForwardPrimitiveDescriptor(algo, src_md, dst_md, p, eps);
724+
this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
736725
}
737726
}
738727
};

0 commit comments

Comments
 (0)