Skip to content

Commit c410774

Browse files
committed
Add support for dim equals 2 in activation functions
1 parent c00a5de commit c410774

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

paddle/fluid/operators/activation_mkldnn_op.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
4040
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
4141

4242
// get memory dim
43-
PADDLE_ENFORCE(src->dims().size() == 4,
44-
"Input dim must be with 4, i.e. NCHW");
43+
PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
44+
"Input dim must be with 2 or 4");
4545
std::vector<int> src_tz = framework::vectorize2int(src->dims());
4646

4747
// create memory description
48-
// TODO(kbinias-intel): support more formats
49-
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
48+
auto data_md = src_tz.size() == 2
49+
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
50+
mkldnn::memory::format::nc)
51+
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
5052
mkldnn::memory::format::nchw);
5153

5254
// create memory primitives
@@ -91,7 +93,10 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
9193
std::vector<int> src_tz = framework::vectorize2int(x->dims());
9294

9395
// create memory description
94-
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
96+
auto data_md = src_tz.size() == 2
97+
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
98+
mkldnn::memory::format::nc)
99+
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
95100
mkldnn::memory::format::nchw);
96101

97102
// create memory primitives

python/paddle/fluid/tests/unittests/test_activation_op.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -535,9 +535,37 @@ def test_check_grad(self):
535535

536536

537537
#--------------------test MKLDNN--------------------
538-
class TestMKLDNNRelu(TestRelu):
538+
class TestMKLDNNReluDim2(TestRelu):
539539
def setUp(self):
540-
super(TestMKLDNNRelu, self).setUp()
540+
super(TestMKLDNNReluDim2, self).setUp()
541+
542+
self.attrs = {"use_mkldnn": True}
543+
544+
545+
class TestMKLDNNTanhDim2(TestTanh):
546+
def setUp(self):
547+
super(TestMKLDNNTanhDim2, self).setUp()
548+
549+
self.attrs = {"use_mkldnn": True}
550+
551+
552+
class TestMKLDNNSqrtDim2(TestSqrt):
553+
def setUp(self):
554+
super(TestMKLDNNSqrtDim2, self).setUp()
555+
556+
self.attrs = {"use_mkldnn": True}
557+
558+
559+
class TestMKLDNNAbsDim2(TestAbs):
560+
def setUp(self):
561+
super(TestMKLDNNAbsDim2, self).setUp()
562+
563+
self.attrs = {"use_mkldnn": True}
564+
565+
566+
class TestMKLDNNReluDim4(TestRelu):
567+
def setUp(self):
568+
super(TestMKLDNNReluDim4, self).setUp()
541569

542570
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
543571
# The same reason with TestAbs
@@ -549,9 +577,9 @@ def setUp(self):
549577
self.attrs = {"use_mkldnn": True}
550578

551579

552-
class TestMKLDNNTanh(TestTanh):
580+
class TestMKLDNNTanhDim4(TestTanh):
553581
def setUp(self):
554-
super(TestMKLDNNTanh, self).setUp()
582+
super(TestMKLDNNTanhDim4, self).setUp()
555583

556584
self.inputs = {
557585
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
@@ -560,9 +588,9 @@ def setUp(self):
560588
self.attrs = {"use_mkldnn": True}
561589

562590

563-
class TestMKLDNNSqrt(TestSqrt):
591+
class TestMKLDNNSqrtDim4(TestSqrt):
564592
def setUp(self):
565-
super(TestMKLDNNSqrt, self).setUp()
593+
super(TestMKLDNNSqrtDim4, self).setUp()
566594

567595
self.inputs = {
568596
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
@@ -571,9 +599,9 @@ def setUp(self):
571599
self.attrs = {"use_mkldnn": True}
572600

573601

574-
class TestMKLDNNAbs(TestAbs):
602+
class TestMKLDNNAbsDim4(TestAbs):
575603
def setUp(self):
576-
super(TestMKLDNNAbs, self).setUp()
604+
super(TestMKLDNNAbsDim4, self).setUp()
577605

578606
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
579607
# The same reason with TestAbs

0 commit comments

Comments (0)