@@ -23,67 +23,56 @@ namespace operators {
23
23
using paddle::framework::Tensor;
24
24
using paddle::platform::MKLDNNDeviceContext;
25
25
26
- struct MKLDNNMatrixSize final {
27
- explicit MKLDNNMatrixSize (const std::vector<int >& in,
28
- const std::vector<int >& w)
29
- : mb{in[0 ]}, ic{in[1 ]}, oc{w[1 ]}, h{in[2 ]}, w{in[3 ]} {}
30
-
31
- bool is_spatial () const { return h > 2 && w > 2 ; }
32
-
33
- const int mb;
34
- const int ic;
35
- const int oc;
36
- const int h, w;
37
- };
38
-
39
26
template <typename T>
40
27
class MKLDNNMD {
41
28
public:
42
29
explicit MKLDNNMD (const T* in, const T* w, bool bias)
43
- : sz_(std::unique_ptr<MKLDNNMatrixSize>(new MKLDNNMatrixSize(
44
- paddle::framework::vectorize2int (in->dims ()),
45
- paddle::framework::vectorize2int(w->dims ())))) {
30
+ : in{paddle::framework::vectorize2int (in->dims ())},
31
+ w{paddle::framework::vectorize2int (w->dims ())} {
46
32
with_bias_ = bias;
47
33
}
48
34
49
35
mkldnn::memory::desc dst () const {
50
- return platform::MKLDNNMemDesc ({sz_-> mb , sz_-> oc },
36
+ return platform::MKLDNNMemDesc ({in[ 0 ], w[ 1 ] },
51
37
mkldnn::memory::data_type::f32 ,
52
38
mkldnn::memory::format::nc);
53
39
}
54
40
55
41
mkldnn::memory::desc src () const {
56
- return sz_-> is_spatial ()
57
- ? platform::MKLDNNMemDesc ({sz_-> mb , sz_-> ic , sz_-> h , sz_-> w },
42
+ return is_spatial ()
43
+ ? platform::MKLDNNMemDesc ({in[ 0 ], in[ 1 ], in[ 2 ], in[ 3 ] },
58
44
mkldnn::memory::data_type::f32 ,
59
45
mkldnn::memory::format::nchw)
60
- : platform::MKLDNNMemDesc ({sz_-> mb , sz_-> ic },
46
+ : platform::MKLDNNMemDesc ({in[ 0 ], in[ 1 ] },
61
47
mkldnn::memory::data_type::f32 ,
62
48
mkldnn::memory::format::nc);
63
49
}
64
50
65
51
mkldnn::memory::desc weights () const {
66
- return sz_-> is_spatial ()
67
- ? platform::MKLDNNMemDesc ({sz_-> oc , sz_-> ic , sz_-> h , sz_-> w },
52
+ return is_spatial ()
53
+ ? platform::MKLDNNMemDesc ({w[ 1 ], in[ 1 ], in[ 2 ], in[ 3 ] },
68
54
mkldnn::memory::data_type::f32 ,
69
55
mkldnn::memory::format::oihw)
70
- : platform::MKLDNNMemDesc ({sz_-> oc , sz_-> ic },
56
+ : platform::MKLDNNMemDesc ({w[ 1 ], in[ 1 ] },
71
57
mkldnn::memory::data_type::f32 ,
72
58
mkldnn::memory::format::oi);
73
59
}
74
60
75
61
mkldnn::memory::desc bias () const {
76
62
return with_bias_
77
- ? platform::MKLDNNMemDesc ({sz_->oc },
78
- mkldnn::memory::data_type::f32 ,
63
+ ? platform::MKLDNNMemDesc ({w[1 ]}, mkldnn::memory::data_type::f32 ,
79
64
mkldnn::memory::format::format_undef)
80
65
: platform::MKLDNNMemDesc ({}, mkldnn::memory::data_type::f32 ,
81
66
mkldnn::memory::format::format_undef);
82
67
}
83
68
84
69
private:
85
- std::unique_ptr<MKLDNNMatrixSize> sz_;
70
+ bool is_spatial () const { return in.size () > 1 && w.size () > 1 ; }
71
+
72
+ std::vector<int > in;
73
+ std::vector<int > w;
86
74
bool with_bias_;
75
+ bool is_spatial_;
87
76
};
88
77
89
78
class MKLDNNMemory {
0 commit comments