Skip to content

Commit ba02ac4

Browse files
authored
use mat attr and refine test (#15448)
* use mat attr and refine test test=develop * add matmul jitcode test=develop * fix mac compile test=develop
1 parent b5ebca4 commit ba02ac4

File tree

16 files changed

+384
-76
lines changed

16 files changed

+384
-76
lines changed

paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,17 @@ void FusionRepeatedFCReluOpMaker::Make() {
7979
}
8080

8181
template <typename T>
82-
static void fc_relu(const T* x, const T* w, const T* b, T* y, int m, int n,
83-
int k) {
82+
static void fc_relu(const T* x, const T* w, const T* b, T* y,
83+
const jit::matmul_attr_t& attr) {
8484
auto matmul =
85-
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(k);
85+
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
8686
auto addbias_relu =
87-
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(n);
88-
matmul(x, w, y, m, n, k);
87+
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(attr.n);
88+
matmul(x, w, y, &attr);
8989
T* dst = y;
90-
for (int i = 0; i < m; ++i) {
91-
addbias_relu(b, dst, dst, n);
92-
dst += n;
90+
for (int i = 0; i < attr.m; ++i) {
91+
addbias_relu(b, dst, dst, attr.n);
92+
dst += attr.n;
9393
}
9494
}
9595

@@ -107,32 +107,33 @@ class FusionRepeatedFCReluKernel : public framework::OpKernel<T> {
107107

108108
auto i_dims = in->dims();
109109
auto w_dims = weights[0]->dims();
110-
int m = i_dims[0];
111-
int n = w_dims[1];
112-
int k = w_dims[0];
113-
relus[0]->Resize({m, n});
110+
jit::matmul_attr_t attr;
111+
attr.m = i_dims[0];
112+
attr.n = w_dims[1];
113+
attr.k = w_dims[0];
114+
relus[0]->Resize({attr.m, attr.n});
114115
fc_relu(in->data<T>(), weights[0]->data<T>(), biases[0]->data<T>(),
115-
relus[0]->mutable_data<T>(place), m, n, k);
116+
relus[0]->mutable_data<T>(place), attr);
116117

117118
for (int i = 1; i < weight_sz - 1; ++i) {
118119
auto i_dims = relus[i - 1]->dims();
119120
auto w_dims = weights[i]->dims();
120-
int m = i_dims[0];
121-
int n = w_dims[1];
122-
int k = w_dims[0];
123-
relus[i]->Resize({m, n});
121+
attr.m = i_dims[0];
122+
attr.n = w_dims[1];
123+
attr.k = w_dims[0];
124+
relus[i]->Resize({attr.m, attr.n});
124125
fc_relu(relus[i - 1]->data<T>(), weights[i]->data<T>(),
125-
biases[i]->data<T>(), relus[i]->mutable_data<T>(place), m, n, k);
126+
biases[i]->data<T>(), relus[i]->mutable_data<T>(place), attr);
126127
}
127128

128129
auto i_dims_last = relus[weight_sz - 2]->dims();
129130
auto w_dims_last = weights[weight_sz - 1]->dims();
130-
m = i_dims_last[0];
131-
n = w_dims_last[1];
132-
k = w_dims_last[0];
131+
attr.m = i_dims_last[0];
132+
attr.n = w_dims_last[1];
133+
attr.k = w_dims_last[0];
133134
fc_relu(relus[weight_sz - 2]->data<T>(), weights[weight_sz - 1]->data<T>(),
134-
biases[weight_sz - 1]->data<T>(), out->mutable_data<T>(place), m, n,
135-
k);
135+
biases[weight_sz - 1]->data<T>(), out->mutable_data<T>(place),
136+
attr);
136137
}
137138
};
138139

paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,23 +87,26 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
8787

8888
auto x_dims = x->dims();
8989
auto y_dims = y->dims();
90-
int m = x_dims[0];
91-
int k = x_dims[1];
92-
int n = y_dims[1];
93-
int o_numel = m * n;
90+
jit::matmul_attr_t attr;
91+
attr.m = x_dims[0];
92+
attr.k = x_dims[1];
93+
attr.n = y_dims[1];
94+
int o_numel = attr.m * attr.n;
9495

9596
auto vsquare_x =
96-
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(m * k);
97+
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.m *
98+
attr.k);
9799
auto vsquare_y =
98-
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(k * n);
100+
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.k *
101+
attr.n);
99102
auto vsquare_xy =
100103
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel);
101104
auto vsub =
102105
jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel);
103106
auto vscal =
104107
jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel);
105108
auto matmul =
106-
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(k);
109+
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
107110

108111
const T* x_data = x->data<T>();
109112
const T* y_data = y->data<T>();
@@ -112,12 +115,12 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
112115
T* squared_xy_data = squared_xy->mutable_data<T>(place);
113116
T* o_data = out->mutable_data<T>(place);
114117

115-
matmul(x_data, y_data, squared_xy_data, m, n, k);
118+
matmul(x_data, y_data, squared_xy_data, &attr);
116119
vsquare_xy(squared_xy_data, squared_xy_data, o_numel);
117120

118-
vsquare_x(x_data, squared_x_data, m * k);
119-
vsquare_y(y_data, squared_y_data, k * n);
120-
matmul(squared_x_data, squared_y_data, o_data, m, n, k);
121+
vsquare_x(x_data, squared_x_data, attr.m * attr.k);
122+
vsquare_y(y_data, squared_y_data, attr.k * attr.n);
123+
matmul(squared_x_data, squared_y_data, o_data, &attr);
121124

122125
vsub(squared_xy_data, o_data, o_data, o_numel);
123126
vscal(&scalar, o_data, o_data, o_numel);

paddle/fluid/operators/jit/benchmark.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,9 @@ void BenchMatMulKernel() {
311311
const T* a_data = a.data<T>();
312312
const T* b_data = b.data<T>();
313313
T* c_data = c.mutable_data<T>(PlaceType());
314-
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(k, a_data, b_data,
315-
c_data, m, n, k);
314+
const jit::matmul_attr_t attr{m, n, k};
315+
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(attr, a_data, b_data,
316+
c_data, &attr);
316317
}
317318
}
318319
}

paddle/fluid/operators/jit/gen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ function(USE_JITKERNEL_GEN TARGET)
99
endfunction()
1010

1111
# use gen jitcode kernel by name
12+
USE_JITKERNEL_GEN(kMatMul)
1213
USE_JITKERNEL_GEN(kVMul)
1314
USE_JITKERNEL_GEN(kVAdd)
1415
USE_JITKERNEL_GEN(kVSub)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License. */
14+
15+
#include "paddle/fluid/operators/jit/gen/matmul.h"
16+
#include <stddef.h> // offsetof
17+
#include <vector>
18+
19+
#include "paddle/fluid/operators/jit/registry.h"
20+
#include "paddle/fluid/platform/cpu_info.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
namespace jit {
25+
namespace gen {
26+
27+
void MatMulJitCode::genCode() {
28+
preCode();
29+
int block, rest;
30+
const auto groups = packed_groups(n_, k_, &block, &rest);
31+
PADDLE_ENFORCE_GT(groups.front(), 0);
32+
33+
const int block_len = sizeof(float) * block;
34+
const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1;
35+
const int w_reg_idx = x_reg_idx - 1;
36+
// from packed mov(reg_ptr_wgt, ptr[param_attr + offsetof(matmul_attr_t,
37+
// packed_weight)]);
38+
mov(reg_ptr_wgt, param_y);
39+
size_t z_offset = 0;
40+
size_t wgt_offset = 0;
41+
for (size_t g = 0; g < groups.size(); ++g) {
42+
size_t x_offset = 0;
43+
for (int k = 0; k < k_; ++k) {
44+
vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]);
45+
// clean
46+
if (k == 0) {
47+
for (int i = 0; i < groups[g]; ++i) {
48+
vxorps(zmm_t(i), zmm_t(i), zmm_t(i));
49+
}
50+
}
51+
for (int i = 0; i < groups[g]; ++i) {
52+
vmovups(zmm_t(w_reg_idx), ptr[reg_ptr_wgt + wgt_offset]);
53+
vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx));
54+
wgt_offset += block_len;
55+
}
56+
// last one, save
57+
if (k == k_ - 1) {
58+
for (int i = 0; i < groups[g]; ++i) {
59+
// only rest save should be careful
60+
if (rest != 0 && g == groups.size() - 1 && i == groups[g] - 1) {
61+
break;
62+
}
63+
vmovups(ptr[param_z + z_offset + i * block_len], zmm_t(i));
64+
}
65+
}
66+
x_offset += sizeof(float);
67+
}
68+
z_offset += block_len * groups[g];
69+
}
70+
71+
if (rest != 0) {
72+
// below should refine with mask
73+
int reg_idx = groups.back() - 1;
74+
z_offset = (n_ - rest) * sizeof(float);
75+
int inner_block = 8;
76+
while (rest > 0) {
77+
if (rest >= 8) {
78+
inner_block = 8;
79+
vmovups(ptr[param_z + z_offset], ymm_t(reg_idx));
80+
// shift zmm of inner_block, change reg_idx if update
81+
} else if (rest >= 4) {
82+
inner_block = 4;
83+
vmovups(ptr[param_z + z_offset], xmm_t(reg_idx));
84+
} else if (rest >= 2) {
85+
inner_block = 2;
86+
vmovq(ptr[param_z + z_offset], xmm_t(reg_idx));
87+
} else {
88+
inner_block = 1;
89+
vmovss(ptr[param_z + z_offset], xmm_t(reg_idx));
90+
}
91+
z_offset += inner_block * sizeof(float);
92+
rest -= inner_block;
93+
}
94+
}
95+
96+
postCode();
97+
}
98+
99+
// Creator of MatMulJitCode: tells the registry when the generated kernel may
// be used for a given matmul attribute, how large its code buffer must be, and
// builds the kernel itself.
class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
 public:
  // The generated code is picked only for a single-row x (m == 1) on machines
  // with avx512f, when n is a multiple of the ZMM float block and k < 512.
  bool UseMe(const matmul_attr_t& attr) const override {
    if (attr.m != 1) {
      return false;
    }
    if (!platform::MayIUse(platform::avx512f)) {
      return false;
    }
    if (attr.n % ZMM_FLOAT_BLOCK != 0) {
      return false;
    }
    return attr.k < 512;
  }

  // Size in bytes of the code buffer requested for the given attribute.
  size_t CodeSize(const matmul_attr_t& attr) const override {
    const int block = platform::MayIUse(platform::avx512f) ? ZMM_FLOAT_BLOCK
                                                           : YMM_FLOAT_BLOCK;
    const auto blocks_per_row = attr.n / block + 1;
    return 96 + 4 * attr.k * blocks_per_row * 8;
  }

  // Builds the kernel; all three dimensions must be positive.
  std::unique_ptr<GenBase> CreateJitCode(
      const matmul_attr_t& attr) const override {
    PADDLE_ENFORCE_GT(attr.m, 0);
    PADDLE_ENFORCE_GT(attr.n, 0);
    PADDLE_ENFORCE_GT(attr.k, 0);
    const size_t code_size = CodeSize(attr);
    return make_unique<MatMulJitCode>(attr, code_size);
  }
};
120+
121+
} // namespace gen
122+
} // namespace jit
123+
} // namespace operators
124+
} // namespace paddle
125+
126+
namespace gen = paddle::operators::jit::gen;
127+
128+
REGISTER_JITKERNEL_GEN(kMatMul, gen::MatMulCreator);
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <stdlib.h> // for malloc and free
18+
#include <string>
19+
#include <vector>
20+
#include "glog/logging.h"
21+
#include "paddle/fluid/operators/jit/gen/jitcode.h"
22+
#include "paddle/fluid/platform/enforce.h"
23+
24+
namespace paddle {
25+
namespace operators {
26+
namespace jit {
27+
namespace gen {
28+
29+
class MatMulJitCode : public JitCode {
30+
public:
31+
explicit MatMulJitCode(const matmul_attr_t& attr,
32+
size_t code_size = 256 * 1024,
33+
void* code_ptr = nullptr)
34+
: JitCode(code_size, code_ptr), m_(attr.m), n_(attr.n), k_(attr.k) {
35+
PADDLE_ENFORCE_EQ(m_, 1, "Only support m==1 yet");
36+
this->genCode();
37+
}
38+
39+
virtual const char* name() const {
40+
std::string base = "MatMulJitCode";
41+
base = base + "_M" + std::to_string(m_) + "_N" + std::to_string(n_) + "_K" +
42+
std::to_string(k_);
43+
return base.c_str();
44+
}
45+
void genCode() override;
46+
47+
private:
48+
int m_, n_, k_;
49+
50+
reg64_t param_x{abi_param1};
51+
reg64_t param_y{abi_param2};
52+
reg64_t param_z{abi_param3};
53+
reg64_t param_attr{abi_param4};
54+
reg64_t reg_tmp{rax};
55+
56+
reg64_t reg_ptr_wgt{r10};
57+
};
58+
59+
} // namespace gen
60+
} // namespace jit
61+
} // namespace operators
62+
} // namespace paddle

paddle/fluid/operators/jit/gen_base.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include <fstream>
1717
#include <iostream>
1818
#include <sstream>
19+
#include <vector>
20+
#include "paddle/fluid/platform/cpu_info.h"
1921

2022
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
2123

@@ -38,6 +40,35 @@ void GenBase::dumpCode(const unsigned char* code) const {
3840
}
3941
}
4042

43+
std::vector<int> packed_groups(int n, int k, int* block_out, int* rest_out) {
44+
int block;
45+
int max_num_regs;
46+
if (platform::MayIUse(platform::avx512f)) {
47+
block = ZMM_FLOAT_BLOCK;
48+
max_num_regs = 32;
49+
} else {
50+
block = YMM_FLOAT_BLOCK;
51+
max_num_regs = 16;
52+
}
53+
// one for x, one for y, others for z
54+
const int max_used_regs_for_n = max_num_regs - 2;
55+
const int aligned_n = n % block == 0 ? n : (n / block + 1) * block;
56+
const int num_block = aligned_n / block;
57+
const int num_groups = num_block / max_used_regs_for_n;
58+
std::vector<int> groups(num_groups, max_used_regs_for_n);
59+
int rest_num_regs = num_block % max_used_regs_for_n;
60+
if (rest_num_regs != 0) {
61+
groups.push_back(rest_num_regs);
62+
}
63+
if (block_out) {
64+
*block_out = block;
65+
}
66+
if (rest_out) {
67+
*rest_out = n % block;
68+
}
69+
return groups;
70+
}
71+
4172
} // namespace jit
4273
} // namespace operators
4374
} // namespace paddle

0 commit comments

Comments
 (0)