@@ -21,6 +21,27 @@ using paddle::framework::LoDTensor;
using paddle::framework::Tensor;
using platform::to_void_cast;

+inline std::vector<int64_t> CalculateReducedDims(const Tensor* input,
+                                                 const Tensor* output,
+                                                 std::vector<int>& reduce_dims,
+                                                 bool reduce_all,
+                                                 bool keep_dim) {
+  if (keep_dim) return framework::vectorize(output->dims());
+
+  if (reduce_all)
+    return std::vector<int64_t>(framework::vectorize(input->dims()).size(), 1);
+
+  std::vector<int64_t> output_dims(framework::vectorize(input->dims()));
+  for (size_t i = 0; i < reduce_dims.size(); ++i) {
+    reduce_dims[i] = (reduce_dims[i] >= 0)
+                         ? reduce_dims[i]
+                         : input->dims().size() + reduce_dims[i];
+    output_dims[reduce_dims[i]] = 1;
+  }
+
+  return output_dims;
+}
+
template <typename T>
class ReduceMKLDNNKernel : public framework::OpKernel<T> {
 public:
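For reference, the helper's core behavior restated as a self-contained sketch over plain vectors (hypothetical name ReducedDimsSketch; illustrative only, not part of the patch — the real helper above additionally short-circuits on keep_dim and reduce_all):

#include <cstdint>
#include <vector>

// Hypothetical mirror of CalculateReducedDims' reduction loop: normalize
// negative axes, then collapse each reduced axis to extent 1.
std::vector<int64_t> ReducedDimsSketch(std::vector<int64_t> dims,
                                       const std::vector<int>& reduce_axes) {
  for (int axis : reduce_axes) {
    if (axis < 0) axis += static_cast<int>(dims.size());
    dims[axis] = 1;
  }
  return dims;
}
// e.g. ReducedDimsSketch({2, 3, 4}, {-1}) yields {2, 3, 1}.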
@@ -37,9 +58,8 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
    bool reduce_all = ctx.Attr<bool>("reduce_all");
    bool keep_dim = ctx.Attr<bool>("keep_dim");

-    std::vector<int64_t> output_dims =
-        CalculateOutputDims(input, output, reduce_dims, reduce_all, keep_dim);
-
+    auto output_dims =
+        CalculateReducedDims(input, output, reduce_dims, reduce_all, keep_dim);
    auto input_dims = framework::vectorize(input->dims());

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
@@ -96,53 +116,63 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
          paddle::framework::vectorize<int64_t>(output->dims()))));
    }
  }
-
- private:
-  std::vector<int64_t> CalculateOutputDims(const Tensor* input,
-                                           const Tensor* output,
-                                           std::vector<int>& reduce_dims,
-                                           bool reduce_all,
-                                           bool keep_dim) const {
-    if (keep_dim) return framework::vectorize(output->dims());
-
-    if (reduce_all)
-      return std::vector<int64_t>(framework::vectorize(input->dims()).size(),
-                                  1);
-
-    std::vector<int64_t> output_dims(framework::vectorize(input->dims()));
-    for (size_t i = 0; i < reduce_dims.size(); ++i) {
-      reduce_dims[i] = (reduce_dims[i] >= 0)
-                           ? reduce_dims[i]
-                           : input->dims().size() + reduce_dims[i];
-      output_dims[reduce_dims[i]] = 1;
-    }
-
-    return output_dims;
-  }
};

template <typename T>
class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void RunKernel(const framework::ExecutionContext& ctx,
-                 dnnl::algorithm binary_type, float scale_x,
-                 float scale_y) const {
+                 dnnl::algorithm binary_type, dnnl::algorithm reduction_type,
+                 float scale_x, float scale_y) const {
    const auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

+    bool keep_dim = ctx.Attr<bool>("keep_dim");
+    bool reduce_all = ctx.Attr<bool>("reduce_all");
    auto dims = ctx.Attr<std::vector<int>>("dim");
    auto* input_dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* output_dx = ctx.Output<Tensor>(framework::GradVarName("X"));

+    mkldnn::memory::format_tag x_format_tag;
+    auto input_dims =
+        CalculateReducedDims(output_dx, input_dy, dims, reduce_all, keep_dim);
+
+    if (input_dims != framework::vectorize(output_dx->dims())) {
+      const std::string key_pd =
+          platform::CreateKey(
+              dev_ctx, framework::vectorize(output_dx->dims()),
+              ctx.InputName("X"),
+              (std::to_string(static_cast<int>(reduction_type)))) +
+          "@fwd_pd";
+      std::shared_ptr<dnnl::reduction::primitive_desc> fwd_pd =
+          std::static_pointer_cast<dnnl::reduction::primitive_desc>(
+              dev_ctx.GetBlob(key_pd));
+
+      PADDLE_ENFORCE_NOT_NULL(
+          fwd_pd, platform::errors::Unavailable(
+                      "Forward primitive descriptor is not available in %s op, "
+                      "cannot deduce memory format tag",
+                      ctx.Type()));
+
+      x_format_tag = platform::GetMKLDNNFormat(fwd_pd->src_desc());
+
+      PADDLE_ENFORCE_NE(x_format_tag, mkldnn::memory::format_tag::undef,
+                        platform::errors::InvalidArgument(
+                            "Cannot deduce format tag for %s op", ctx.Type()));
+    } else {  // fwd descriptor not available because reorder was used instead
+              // of reduction
+      x_format_tag = getPlainFormatTag(output_dx);
+    }
+
    output_dx->mutable_data<T>(ctx.GetPlace());
-    output_dx->set_format(getPlainFormatTag(output_dx));
+    output_dx->set_format(x_format_tag);
    output_dx->set_layout(input_dy->layout());

    platform::BroadcastDataMKLDNNHandler<T> handler(
        binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx,
        input_dy, scale_x, scale_y,
-        ctx.InputName(framework::GradVarName("Out")));
+        ctx.InputName(framework::GradVarName("Out")), input_dims);

    const auto src_dx_memory = handler.AcquireSrcMemory(output_dx);
    const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy);
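Note on the else branch above: per its own comment, the forward kernel caches a dnnl::reduction::primitive_desc blob only when a real reduction ran; when the reduced shape equals the input shape it used a reorder instead and cached nothing, so the grad kernel falls back to a plain format tag. A minimal sketch of the keyed-blob lookup pattern this relies on (hypothetical GetCachedFwdPd and blob_cache; the actual code uses platform::CreateKey and dev_ctx.GetBlob as shown in the diff):

#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the device context's blob cache: the forward
// kernel stores its primitive descriptor under a "<key>@fwd_pd" entry and
// the grad kernel rebuilds the same key to fetch it.
static std::unordered_map<std::string, std::shared_ptr<void>> blob_cache;

template <typename PD>
std::shared_ptr<PD> GetCachedFwdPd(const std::string& key) {
  auto it = blob_cache.find(key + "@fwd_pd");
  if (it == blob_cache.end()) return nullptr;  // e.g. reorder path cached no pd
  return std::static_pointer_cast<PD>(it->second);
}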