Skip to content

Commit 5094568

Browse files
committed
add hmax, hsum jitcode
test=develop
1 parent 8117725 commit 5094568

File tree

5 files changed

+200
-1
lines changed

5 files changed

+200
-1
lines changed

paddle/fluid/operators/jit/gen/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,5 @@ USE_JITKERNEL_GEN(kGRUHtPart1)
2828
USE_JITKERNEL_GEN(kGRUHtPart2)
2929
USE_JITKERNEL_GEN(kNCHW16CMulNC)
3030
USE_JITKERNEL_GEN(kSeqPool)
31+
USE_JITKERNEL_GEN(kHMax)
32+
USE_JITKERNEL_GEN(kHSum)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License. */
14+
15+
#include "paddle/fluid/operators/jit/gen/hopv.h"
16+
#include "paddle/fluid/operators/jit/registry.h"
17+
#include "paddle/fluid/platform/cpu_info.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
namespace jit {
22+
namespace gen {
23+
24+
void HOPVJitCode::genCode() {
25+
const int num_blocks = num_ / YMM_FLOAT_BLOCK;
26+
int offset = 0;
27+
28+
if (num_blocks > 0) {
29+
// load one firstly
30+
vmovups(ymm_tmp, ptr[param_src]);
31+
offset += sizeof(float) * YMM_FLOAT_BLOCK;
32+
for (int i = 1; i < num_blocks; ++i) {
33+
vmovups(ymm_src, ptr[param_src + offset]);
34+
process(ymm_tmp, ymm_src, ymm_tmp);
35+
offset += sizeof(float) * YMM_FLOAT_BLOCK;
36+
}
37+
vextractf128(xmm_dst, ymm_tmp, 1);
38+
process(xmm_dst, xmm_dst, xmm_tmp);
39+
} else {
40+
if (type_ == operand_type::MAX) {
41+
vbroadcastss(ymm_dst, ptr[param_src]);
42+
} else if (type_ == operand_type::ADD) {
43+
vxorps(ymm_dst, ymm_dst, ymm_dst);
44+
}
45+
}
46+
47+
int rest = num_ % YMM_FLOAT_BLOCK;
48+
if (rest >= 4) {
49+
vmovups(xmm_src, ptr[param_src + offset]);
50+
offset += sizeof(float) * 4;
51+
rest -= 4;
52+
process(xmm_dst, xmm_dst, xmm_src);
53+
}
54+
55+
vpermilps(xmm_tmp, xmm_dst, 16 + 8 + 3);
56+
process(xmm_dst, xmm_dst, xmm_tmp);
57+
58+
if (rest >= 2) {
59+
vmovq(xmm_src, ptr[param_src + offset]);
60+
offset += sizeof(float) * 2;
61+
rest -= 2;
62+
process(xmm_dst, xmm_dst, xmm_src);
63+
}
64+
65+
vpermilps(xmm_tmp, xmm_dst, 1);
66+
process(xmm_dst, xmm_dst, xmm_tmp);
67+
68+
if (rest >= 1) {
69+
vmovss(xmm_src, ptr[param_src + offset]);
70+
process(xmm_dst, xmm_dst, xmm_src);
71+
}
72+
vmovss(ptr[param_dst], xmm_dst);
73+
ret();
74+
}
75+
76+
#define DECLARE_HOP_CREATOR(name) \
77+
class name##Creator : public JitCodeCreator<int> { \
78+
public: \
79+
bool UseMe(const int& attr) const override { \
80+
return platform::MayIUse(platform::avx); \
81+
} \
82+
size_t CodeSize(const int& d) const override { \
83+
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
84+
} \
85+
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
86+
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
87+
} \
88+
}
89+
90+
DECLARE_HOP_CREATOR(HMax);
91+
DECLARE_HOP_CREATOR(HSum);
92+
93+
#undef DECLARE_HOP_CREATOR
94+
95+
} // namespace gen
96+
} // namespace jit
97+
} // namespace operators
98+
} // namespace paddle
99+
100+
namespace gen = paddle::operators::jit::gen;
101+
102+
REGISTER_JITKERNEL_GEN(kHMax, gen::HMaxCreator);
103+
REGISTER_JITKERNEL_GEN(kHSum, gen::HSumCreator);

paddle/fluid/operators/jit/gen/hopv.h

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include "glog/logging.h"
19+
#include "paddle/fluid/operators/jit/gen/jitcode.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
namespace jit {
24+
namespace gen {
25+
26+
// horizontal operand vector
27+
class HOPVJitCode : public JitCode {
28+
public:
29+
explicit HOPVJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
30+
void* code_ptr = nullptr)
31+
: JitCode(code_size, code_ptr), num_(d), type_(type) {
32+
if (!(type_ == operand_type::MAX || type_ == operand_type::ADD)) {
33+
LOG(FATAL) << "Do not support this operand type: " << type_;
34+
}
35+
this->genCode();
36+
}
37+
38+
virtual const char* name() const {
39+
std::string base = "VXXJitCode";
40+
if (type_ == operand_type::MAX) {
41+
base += "_MAX";
42+
} else {
43+
base += "_SUM";
44+
}
45+
return base.c_str();
46+
}
47+
void genCode() override;
48+
49+
protected:
50+
template <typename JMM>
51+
void process(JMM& dst, JMM& src1, JMM& src2) { // NOLINT
52+
if (type_ == operand_type::MAX) {
53+
vmaxps(dst, src1, src2);
54+
} else if (type_ == operand_type::ADD) {
55+
vaddps(dst, src1, src2);
56+
}
57+
}
58+
59+
private:
60+
int num_;
61+
operand_type type_;
62+
reg64_t param_src{abi_param1};
63+
reg64_t param_dst{abi_param2};
64+
reg64_t param_attr{abi_param3};
65+
66+
ymm_t ymm_tmp = ymm_t(0);
67+
ymm_t ymm_src = ymm_t(1);
68+
ymm_t ymm_dst = ymm_t(2);
69+
70+
xmm_t xmm_tmp = xmm_t(0);
71+
xmm_t xmm_src = xmm_t(1);
72+
xmm_t xmm_dst = xmm_t(2);
73+
};
74+
75+
#define DECLARE_HOP_JITCODE(name, op_type) \
76+
class name##JitCode : public HOPVJitCode { \
77+
public: \
78+
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
79+
: HOPVJitCode(d, op_type, code_size, code_ptr) {} \
80+
};
81+
82+
DECLARE_HOP_JITCODE(HMax, operand_type::MAX);
83+
DECLARE_HOP_JITCODE(HSum, operand_type::ADD);
84+
85+
#undef DECLARE_HOP_JITCODE
86+
87+
} // namespace gen
88+
} // namespace jit
89+
} // namespace operators
90+
} // namespace paddle

paddle/fluid/operators/jit/gen/jitcode.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ using Label = Xbyak::Label;
4747

4848
typedef enum {
4949
MUL = 0,
50+
MAX,
5051
ADD,
5152
SUB,
5253
RELU,

paddle/fluid/operators/jit/test.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,16 +383,19 @@ void TestAXYNKernel() {
383383
template <jit::KernelType KT, typename T, typename PlaceType>
384384
void TestXRNKernel() {
385385
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
386+
auto last_acc = acc;
387+
acc = 1e-4;
386388
for (int d : TestSizes()) {
387389
auto ref = jit::GetRefer<KT, jit::XRNTuples<T>>();
388390
EXPECT_TRUE(ref != nullptr);
389391
std::vector<T> x(d);
390-
RandomVec<T>(d, x.data());
392+
RandomVec<T>(d, x.data(), -2.f, 2.f);
391393
T ref_res;
392394
ref(x.data(), &ref_res, d);
393395
TestAllImpls<KT, jit::XRNTuples<T>, PlaceType, std::vector<T>, T>(d, x,
394396
ref_res);
395397
}
398+
acc = last_acc;
396399
}
397400

398401
template <jit::KernelType KT, typename T, typename PlaceType>

0 commit comments

Comments
 (0)