Skip to content

Commit d96ac8e

Browse files
committed
feat: Added test fixtures for faster test writing
1 parent cd1c1ee commit d96ac8e

26 files changed

+1206
-266
lines changed

CMakeLists.txt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ set(KERNEL_FILES
6565
matmul_16_6_1.h
6666
matmul_16_6_k.cpp
6767
matmul_16_6_k.h
68+
matmul_16m_4n_k.h
69+
matmul_16m_4n_k.cpp
6870
)
6971

7072
set(ARM_INSTRUCTION_FILES
@@ -84,6 +86,7 @@ set(ARM_INSTRUCTION_FILES
8486
base/sub.h
8587
base/mov.h
8688
base/orr.h
89+
base/madd.h
8790

8891
simd_fp/ld1.h
8992
simd_fp/st1.h
@@ -92,10 +95,17 @@ set(ARM_INSTRUCTION_FILES
9295
simd_fp/stp.h
9396
)
9497

98+
set(TEST_FILES
99+
BaseGeneration.test.h
100+
BaseGeneration.test.cpp
101+
Brgemm.test.cpp
102+
)
103+
95104
set(TEST_KERNELS
96105
matmul.test.h
97106
matmul_16_6_1.test.cpp
98107
matmul_16_6_k.test.cpp
108+
matmul_16m_4n_k.test.cpp
99109
)
100110

101111
set(TEST_ARM_INSTRUCTION_FILES
@@ -109,6 +119,7 @@ set(TEST_ARM_INSTRUCTION_FILES
109119
base/orr.test.cpp
110120
base/mov.test.cpp
111121
base/movz.test.cpp
122+
base/madd.test.cpp
112123

113124
simd_fp/fmla.test.cpp
114125
simd_fp/ld1.test.cpp
@@ -129,11 +140,14 @@ foreach(file ${KERNEL_FILES})
129140
list(APPEND SOURCE_FILEPATHS src/main/kernels/${file})
130141
endforeach()
131142

132-
133143
foreach(file ${ARM_INSTRUCTION_FILES})
134144
list(APPEND SOURCE_FILEPATHS src/main/arm_instructions/${file})
135145
endforeach()
136146

147+
foreach(file ${TEST_FILES})
148+
list(APPEND TEST_FILEPATHS src/test/${file})
149+
endforeach()
150+
137151
foreach(file ${TEST_KERNELS})
138152
list(APPEND TEST_FILEPATHS src/test/kernels/${file})
139153
endforeach()

gen_correct.dis

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
2+
build/matmul_16m_6n_k.bin: file format binary
3+
4+
5+
Disassembly of section .data:
6+
7+
0000000000000000 <.data>:
8+
0: 6dbf27e8 stp d8, d9, [sp, #-16]!
9+
4: d37ef463 lsl x3, x3, #2
10+
8: d37ef484 lsl x4, x4, #2
11+
c: d37ef4a5 lsl x5, x5, #2
12+
10: aa0103e6 mov x6, x1
13+
14: aa0203e7 mov x7, x2
14+
18: aa0003e8 mov x8, x0
15+
1c: aa0103e9 mov x9, x1
16+
20: aa0003ea mov x10, x0
17+
24: aa0203eb mov x11, x2
18+
28: d280008c mov x12, #0x4 // #4
19+
2c: d2800211 mov x17, #0x10 // #16
20+
30: d1000631 sub x17, x17, #0x1
21+
34: aa0a03e8 mov x8, x10
22+
38: aa0b03e7 mov x7, x11
23+
3c: d2800090 mov x16, #0x4 // #4
24+
40: d1000610 sub x16, x16, #0x1
25+
44: aa0703e2 mov x2, x7
26+
48: aa0803e0 mov x0, x8
27+
4c: aa0903e6 mov x6, x9
28+
50: aa0903e1 mov x1, x9
29+
54: 4cc52859 ld1 {v25.4s-v28.4s}, [x2], x5
30+
58: 4cc52851 ld1 {v17.4s-v20.4s}, [x2], x5
31+
5c: 4cc52855 ld1 {v21.4s-v24.4s}, [x2], x5
32+
60: 4cc52845 ld1 {v5.4s-v8.4s}, [x2], x5
33+
64: d280002f mov x15, #0x1 // #1
34+
68: d10005ef sub x15, x15, #0x1
35+
6c: 4cc32800 ld1 {v0.4s-v3.4s}, [x0], x3
36+
70: bd400024 ldr s4, [x1]
37+
74: 8b040021 add x1, x1, x4
38+
78: 4f841019 fmla v25.4s, v0.4s, v4.s[0]
39+
7c: 4f84103a fmla v26.4s, v1.4s, v4.s[0]
40+
80: 4f84105b fmla v27.4s, v2.4s, v4.s[0]
41+
84: 4f84107c fmla v28.4s, v3.4s, v4.s[0]
42+
88: bd400024 ldr s4, [x1]
43+
8c: 8b040021 add x1, x1, x4
44+
90: 4f841011 fmla v17.4s, v0.4s, v4.s[0]
45+
94: 4f841032 fmla v18.4s, v1.4s, v4.s[0]
46+
98: 4f841053 fmla v19.4s, v2.4s, v4.s[0]
47+
9c: 4f841074 fmla v20.4s, v3.4s, v4.s[0]
48+
a0: bd400024 ldr s4, [x1]
49+
a4: 8b040021 add x1, x1, x4
50+
a8: 4f841015 fmla v21.4s, v0.4s, v4.s[0]
51+
ac: 4f841036 fmla v22.4s, v1.4s, v4.s[0]
52+
b0: 4f841057 fmla v23.4s, v2.4s, v4.s[0]
53+
b4: 4f841078 fmla v24.4s, v3.4s, v4.s[0]
54+
b8: bd400024 ldr s4, [x1]
55+
bc: 8b040021 add x1, x1, x4
56+
c0: 4f841005 fmla v5.4s, v0.4s, v4.s[0]
57+
c4: 4f841026 fmla v6.4s, v1.4s, v4.s[0]
58+
c8: 4f841047 fmla v7.4s, v2.4s, v4.s[0]
59+
cc: 4f841068 fmla v8.4s, v3.4s, v4.s[0]
60+
d0: 910010c6 add x6, x6, #0x4
61+
d4: aa0603e1 mov x1, x6
62+
d8: b5fffc8f cbnz x15, 0x68
63+
dc: aa0703e2 mov x2, x7
64+
e0: 4c852859 st1 {v25.4s-v28.4s}, [x2], x5
65+
e4: 4c852851 st1 {v17.4s-v20.4s}, [x2], x5
66+
e8: 4c852855 st1 {v21.4s-v24.4s}, [x2], x5
67+
ec: 4c852845 st1 {v5.4s-v8.4s}, [x2], x5
68+
f0: 910100e7 add x7, x7, #0x40
69+
f4: 91010108 add x8, x8, #0x40
70+
f8: b5fffa50 cbnz x16, 0x40
71+
fc: 9b0c2489 madd x9, x4, x12, x9
72+
100: 9b0c2cab madd x11, x5, x12, x11
73+
104: b5fff971 cbnz x17, 0x30
74+
108: 6cc127e8 ldp d8, d9, [sp], #16
75+
10c: d65f03c0 ret

gen_false.dis

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
2+
build/matmul_16m_6n_k.bin: file format binary
3+
4+
5+
Disassembly of section .data:
6+
7+
0000000000000000 <.data>:
8+
0: d37ef463 lsl x3, x3, #2
9+
4: d37ef484 lsl x4, x4, #2
10+
8: d37ef4a5 lsl x5, x5, #2
11+
c: 4c402800 ld1 {v0.4s-v3.4s}, [x0]
12+
10: bd400024 ldr s4, [x1]
13+
14: 8b040021 add x1, x1, x4
14+
18: 4c402859 ld1 {v25.4s-v28.4s}, [x2]
15+
1c: 4f841019 fmla v25.4s, v0.4s, v4.s[0]
16+
20: 4f84103a fmla v26.4s, v1.4s, v4.s[0]
17+
24: 4f84105b fmla v27.4s, v2.4s, v4.s[0]
18+
28: 4f84107c fmla v28.4s, v3.4s, v4.s[0]
19+
2c: 4c852859 st1 {v25.4s-v28.4s}, [x2], x5
20+
30: bd400024 ldr s4, [x1]
21+
34: 8b040021 add x1, x1, x4
22+
38: 4c402851 ld1 {v17.4s-v20.4s}, [x2]
23+
3c: 4f841011 fmla v17.4s, v0.4s, v4.s[0]
24+
40: 4f841032 fmla v18.4s, v1.4s, v4.s[0]
25+
44: 4f841053 fmla v19.4s, v2.4s, v4.s[0]
26+
48: 4f841074 fmla v20.4s, v3.4s, v4.s[0]
27+
4c: 4c852851 st1 {v17.4s-v20.4s}, [x2], x5
28+
50: bd400024 ldr s4, [x1]
29+
54: 8b040021 add x1, x1, x4
30+
58: 4c402855 ld1 {v21.4s-v24.4s}, [x2]
31+
5c: 4f841015 fmla v21.4s, v0.4s, v4.s[0]
32+
60: 4f841036 fmla v22.4s, v1.4s, v4.s[0]
33+
64: 4f841057 fmla v23.4s, v2.4s, v4.s[0]
34+
68: 4f841078 fmla v24.4s, v3.4s, v4.s[0]
35+
6c: 4c852855 st1 {v21.4s-v24.4s}, [x2], x5
36+
70: bd400024 ldr s4, [x1]
37+
74: 8b040021 add x1, x1, x4
38+
78: 4c402859 ld1 {v25.4s-v28.4s}, [x2]
39+
7c: 4f841019 fmla v25.4s, v0.4s, v4.s[0]
40+
80: 4f84103a fmla v26.4s, v1.4s, v4.s[0]
41+
84: 4f84105b fmla v27.4s, v2.4s, v4.s[0]
42+
88: 4f84107c fmla v28.4s, v3.4s, v4.s[0]
43+
8c: 4c852859 st1 {v25.4s-v28.4s}, [x2], x5
44+
90: bd400024 ldr s4, [x1]
45+
94: 8b040021 add x1, x1, x4
46+
98: 4c402851 ld1 {v17.4s-v20.4s}, [x2]
47+
9c: 4f841011 fmla v17.4s, v0.4s, v4.s[0]
48+
a0: 4f841032 fmla v18.4s, v1.4s, v4.s[0]
49+
a4: 4f841053 fmla v19.4s, v2.4s, v4.s[0]
50+
a8: 4f841074 fmla v20.4s, v3.4s, v4.s[0]
51+
ac: 4c852851 st1 {v17.4s-v20.4s}, [x2], x5
52+
b0: bd400024 ldr s4, [x1]
53+
b4: 8b040021 add x1, x1, x4
54+
b8: 4c402855 ld1 {v21.4s-v24.4s}, [x2]
55+
bc: 4f841015 fmla v21.4s, v0.4s, v4.s[0]
56+
c0: 4f841036 fmla v22.4s, v1.4s, v4.s[0]
57+
c4: 4f841057 fmla v23.4s, v2.4s, v4.s[0]
58+
c8: 4f841078 fmla v24.4s, v3.4s, v4.s[0]
59+
cc: 4c852855 st1 {v21.4s-v24.4s}, [x2], x5
60+
d0: d65f03c0 ret
61+
d4: 6dbf27e8 stp d8, d9, [sp, #-16]!
62+
d8: d37ef463 lsl x3, x3, #2
63+
dc: d37ef484 lsl x4, x4, #2
64+
e0: d37ef4a5 lsl x5, x5, #2
65+
e4: aa0103e6 mov x6, x1
66+
e8: aa0203e7 mov x7, x2
67+
ec: aa0003e8 mov x8, x0
68+
f0: aa0103e9 mov x9, x1
69+
f4: aa0003ea mov x10, x0
70+
f8: aa0203eb mov x11, x2
71+
fc: d280008c mov x12, #0x4 // #4
72+
100: d2800211 mov x17, #0x10 // #16
73+
104: d1000631 sub x17, x17, #0x1
74+
108: aa0a03e8 mov x8, x10
75+
10c: aa0b03e7 mov x7, x11
76+
110: d2800090 mov x16, #0x4 // #4
77+
114: d1000610 sub x16, x16, #0x1
78+
118: aa0703e2 mov x2, x7
79+
11c: aa0803e0 mov x0, x8
80+
120: aa0903e6 mov x6, x9
81+
124: aa0903e1 mov x1, x9
82+
128: 4cc52859 ld1 {v25.4s-v28.4s}, [x2], x5
83+
12c: 4cc52851 ld1 {v17.4s-v20.4s}, [x2], x5
84+
130: 4cc52855 ld1 {v21.4s-v24.4s}, [x2], x5
85+
134: 4cc52845 ld1 {v5.4s-v8.4s}, [x2], x5
86+
138: d280002f mov x15, #0x1 // #1
87+
13c: d10005ef sub x15, x15, #0x1
88+
140: 4cc32800 ld1 {v0.4s-v3.4s}, [x0], x3
89+
144: bd400024 ldr s4, [x1]
90+
148: 8b040021 add x1, x1, x4
91+
14c: 4f841019 fmla v25.4s, v0.4s, v4.s[0]
92+
150: 4f84103a fmla v26.4s, v1.4s, v4.s[0]
93+
154: 4f84105b fmla v27.4s, v2.4s, v4.s[0]
94+
158: 4f84107c fmla v28.4s, v3.4s, v4.s[0]
95+
15c: bd400024 ldr s4, [x1]
96+
160: 8b040021 add x1, x1, x4
97+
164: 4f841011 fmla v17.4s, v0.4s, v4.s[0]
98+
168: 4f841032 fmla v18.4s, v1.4s, v4.s[0]
99+
16c: 4f841053 fmla v19.4s, v2.4s, v4.s[0]
100+
170: 4f841074 fmla v20.4s, v3.4s, v4.s[0]
101+
174: bd400024 ldr s4, [x1]
102+
178: 8b040021 add x1, x1, x4
103+
17c: 4f841015 fmla v21.4s, v0.4s, v4.s[0]
104+
180: 4f841036 fmla v22.4s, v1.4s, v4.s[0]
105+
184: 4f841057 fmla v23.4s, v2.4s, v4.s[0]
106+
188: 4f841078 fmla v24.4s, v3.4s, v4.s[0]
107+
18c: bd400024 ldr s4, [x1]
108+
190: 8b040021 add x1, x1, x4
109+
194: 4f841005 fmla v5.4s, v0.4s, v4.s[0]
110+
198: 4f841026 fmla v6.4s, v1.4s, v4.s[0]
111+
19c: 4f841047 fmla v7.4s, v2.4s, v4.s[0]
112+
1a0: 4f841068 fmla v8.4s, v3.4s, v4.s[0]
113+
1a4: 910010c6 add x6, x6, #0x4
114+
1a8: aa0603e1 mov x1, x6
115+
1ac: b5fffc8f cbnz x15, 0x13c
116+
1b0: aa0703e2 mov x2, x7
117+
1b4: 4c852859 st1 {v25.4s-v28.4s}, [x2], x5
118+
1b8: 4c852851 st1 {v17.4s-v20.4s}, [x2], x5
119+
1bc: 4c852855 st1 {v21.4s-v24.4s}, [x2], x5
120+
1c0: 4c852845 st1 {v5.4s-v8.4s}, [x2], x5
121+
1c4: 910100e7 add x7, x7, #0x40
122+
1c8: 91010108 add x8, x8, #0x40
123+
1cc: b5fffa50 cbnz x16, 0x114
124+
1d0: 9b0c2489 madd x9, x4, x12, x9
125+
1d4: 9b0c2cab madd x11, x5, x12, x11
126+
1d8: b5fff971 cbnz x17, 0x104
127+
1dc: 6cc127e8 ldp d8, d9, [sp], #16
128+
1e0: d65f03c0 ret

src/main/Brgemm.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "Brgemm.h"
22
#include "kernels/matmul_16_6_1.h"
33
#include "kernels/matmul_16_6_k.h"
4+
#include "kernels/matmul_16m_4n_k.h"
45
#include "Kernel.h"
56

67
mini_jit::Brgemm::error_t mini_jit::Brgemm::generate(uint32_t m, uint32_t n, uint32_t k, uint32_t br_size, uint32_t trans_a, uint32_t trans_b, uint32_t trans_c, dtype_t dtype)
@@ -9,7 +10,7 @@ mini_jit::Brgemm::error_t mini_jit::Brgemm::generate(uint32_t m, uint32_t n, uin
910
{
1011
return error_t::err_wrong_dtype;
1112
}
12-
if (m != 16 || n != 6)
13+
if (m % 16 != 0 || !(n % 4 == 0 || n == 6))
1314
{
1415
return error_t::err_wrong_dimension;
1516
}
@@ -22,12 +23,21 @@ mini_jit::Brgemm::error_t mini_jit::Brgemm::generate(uint32_t m, uint32_t n, uin
2223
return error_t::err_batch_reduce_size_not_supported;
2324
}
2425

25-
if (k == 1)
26+
if (m == 16 && n == 6 && k == 1)
2627
{
27-
kernels::matmul_16_6_1(native_kernel);
28+
kernels::matmul_16_6_1(native_kernel);
2829
}
29-
30-
kernels::matmul_16_6_k(native_kernel, k);
30+
31+
if (m == 16 && n == 6)
32+
{
33+
kernels::matmul_16_6_k(native_kernel, k);
34+
}
35+
36+
if (m % 16 == 0 && n % 4 == 0)
37+
{
38+
kernels::matmul_16m_4n_k(native_kernel, m / 16, n / 4, k);
39+
}
40+
3141

3242
native_kernel.set_kernel();
3343
kernel = reinterpret_cast<kernel_t>(const_cast<void*>(native_kernel.get_kernel())); // Properly cast from const void* to kernel_t

src/main/arm_instructions/base/base_all.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
#include "orr.h"
1414
#include "mov.h"
1515
#include "movz.h"
16+
#include "madd.h"
1617

1718
#endif // MINI_JIT_ARM_INSTRUCTIONS_BASE_ALL_H
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#ifndef MINI_JIT_ARM_INSTRUCTIONS_BASE_MADD_H
2+
#define MINI_JIT_ARM_INSTRUCTIONS_BASE_MADD_H
3+
4+
#include <cstdint>
5+
#include "../../release_assert.h"
6+
#include "../register.h"
7+
8+
namespace mini_jit
9+
{
10+
namespace arm_instructions
11+
{
12+
namespace internal
13+
{
14+
15+
constexpr uint32_t madd(const uint32_t Rd, const uint32_t Rn, const uint32_t Rm, const uint32_t Ra,
16+
const bool is64bit)
17+
{
18+
release_assert((Rd & mask5) == Rd, "Rd is only allowed to have a size of 5 bit.");
19+
release_assert((Rn & mask5) == Rn, "Rn is only allowed to have a size of 5 bit.");
20+
release_assert((Rm & mask5) == Rm, "Rm is only allowed to have a size of 5 bit.");
21+
release_assert((Ra & mask5) == Ra, "Ra is only allowed to have a size of 5 bit.");
22+
23+
uint32_t add = 0;
24+
add |= (is64bit & mask1) << 31;
25+
add |= 0b0011011000 << 21;
26+
add |= (Rm & mask5) << 16;
27+
add |= 0b0 << 15;
28+
add |= (Ra & mask5) << 10;
29+
add |= (Rn & mask5) << 5;
30+
add |= (Rd & mask5) << 0;
31+
return add;
32+
}
33+
34+
} // namespace internal
35+
36+
constexpr uint32_t madd(const R32Bit Wd, const R32Bit Wn, const R32Bit Wm, const R32Bit Wa)
37+
{
38+
return internal::madd(static_cast<uint32_t>(Wd), static_cast<uint32_t>(Wn), static_cast<uint32_t>(Wm),
39+
static_cast<uint32_t>(Wa), false);
40+
}
41+
42+
constexpr uint32_t madd(const R64Bit Xd, const R64Bit Xn, const R64Bit Xm, const R64Bit Xa)
43+
{
44+
return internal::madd(static_cast<uint32_t>(Xd), static_cast<uint32_t>(Xn), static_cast<uint32_t>(Xm),
45+
static_cast<uint32_t>(Xa), true);
46+
}
47+
48+
} // namespace arm_instructions
49+
} // namespace mini_jit
50+
51+
#endif // MINI_JIT_ARM_INSTRUCTIONS_BASE_MADD_H

0 commit comments

Comments
 (0)