Skip to content

Commit ad86739

Browse files
authored
fix repeated_fc_fuse_pass and jit::matmul bug test=develop test=release/1.6 (#20948)
- fix jit::matmul bug for input x of shape (m, k) and weight of shape (k, n)
1 parent ca92f5c commit ad86739

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

paddle/fluid/operators/jit/gen/matmul.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,12 @@ void MatMulJitCode::genCode() {
4040
size_t wgt_offset = 0;
4141
for (size_t g = 0; g < groups.size(); ++g) {
4242
size_t x_offset = 0;
43+
size_t wgt_offset_tmp = 0;
44+
for (int i = 0; i < g; ++i) {
45+
wgt_offset_tmp += groups[i] * block_len;
46+
}
4347
for (int k = 0; k < k_; ++k) {
48+
wgt_offset = wgt_offset_tmp;
4449
vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]);
4550
// clean
4651
if (k == 0) {
@@ -49,7 +54,8 @@ void MatMulJitCode::genCode() {
4954
}
5055
}
5156
for (int i = 0; i < groups[g]; ++i) {
52-
vmovups(zmm_t(w_reg_idx), ptr[reg_ptr_wgt + wgt_offset]);
57+
vmovups(zmm_t(w_reg_idx),
58+
ptr[reg_ptr_wgt + wgt_offset + k * n_ * sizeof(float)]);
5359
vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx));
5460
wgt_offset += block_len;
5561
}

python/paddle/fluid/tests/unittests/test_fusion_repeated_fc_relu_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def set_conf(self):
7878
class TestFusionRepeatedFCReluOpBS1(TestFusionRepeatedFCReluOp):
7979
def set_conf(self):
8080
self.bs = 1
81-
self.oc = [4, 2, 7, 5]
81+
self.oc = [4, 2, 7, 5, 512, 1024]
8282

8383

8484
if __name__ == '__main__':

0 commit comments

Comments
 (0)