Skip to content

Commit d3eae8f

Browse files
committed
refine relu and fix addrelu test
1 parent 4e67fe6 commit d3eae8f

File tree

3 files changed

+7
-15
lines changed

3 files changed

+7
-15
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,6 @@ bool VActJitCode::init(int d, operand_type type) {
177177
}
178178
}
179179

180-
void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
181-
vmaxps(ymm_dst, ymm_zero, ymm_src);
182-
}
183-
184-
void VActJitCode::relu_xmm(xmm_t& xmm_dst, xmm_t& xmm_src, xmm_t& xmm_zero) {
185-
vmaxps(xmm_dst, xmm_zero, xmm_src);
186-
}
187-
188180
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
189181
int fy_idx, int mask_idx, int tmp_idx) {
190182
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enforce
@@ -378,7 +370,7 @@ void VActJitCode::generate() {
378370
vmovups(ymm_src, ptr[param1 + offset]);
379371
switch (type_) {
380372
case operand_type::relu:
381-
relu_ymm(ymm_dst, ymm_src, ymm_zero);
373+
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
382374
break;
383375
case operand_type::exp:
384376
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
@@ -414,7 +406,7 @@ void VActJitCode::generate() {
414406
}
415407
switch (type_) {
416408
case operand_type::relu:
417-
relu_xmm(xmm_dst, xmm_src, xmm_zero);
409+
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
418410
break;
419411
case operand_type::exp:
420412
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);

paddle/fluid/operators/math/jit_code.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ class VActJitCode : public JitCode {
128128

129129
protected:
130130
// compute relu with ymm, xmm
131-
void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
132-
const Xbyak::Ymm& zero);
133-
void relu_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src,
134-
const Xbyak::Xmm& zero);
131+
// Emits a single vmaxps with operands (dst, src, zero) to compute relu.
// JMM is the register type; the call sites shown in this commit instantiate
// it with ymm_t and xmm_t, replacing the separate relu_ymm/relu_xmm helpers.
// NOTE(review): the removed helpers called vmaxps(dst, zero, src); the two
// source operands are swapped here. Presumably this only matters for NaN /
// signed-zero inputs -- confirm against the ISA reference that the new order
// is intended.
template <typename JMM>
132+
void relu_jmm(JMM& dst, JMM& src, JMM& zero) { // NOLINT
133+
vmaxps(dst, src, zero);
134+
}
135135

136136
// compute exp with ymm, xmm
137137
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,7 @@ TEST(JitKernel, vaddrelu) {
762762
float* zref_data = zref.data();
763763
auto trefs = GetCurrentUS();
764764
for (int i = 0; i < repeat; ++i) {
765-
vadd_ref(d, x_data, y_data, zref_data);
765+
vaddrelu_ref(d, x_data, y_data, zref_data);
766766
}
767767
auto trefe = GetCurrentUS();
768768
auto tmkls = GetCurrentUS();

0 commit comments

Comments
 (0)