Skip to content

Commit 30e988c

Browse files
committed
feat: add matmul_lt16_lt4nRest_k & extended matmul_16mRest_lt4nRest :)
1 parent 20c6797 commit 30e988c

12 files changed

+1202
-134
lines changed

CMakeLists.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,12 @@ set(KERNEL_FILES
8282
matmul_16m_lt4nRest_k.cpp
8383
matmul_16mRest_4n_k.h
8484
matmul_16mRest_4n_k.cpp
85-
matmul_16mRest_4nRest_k.h
86-
matmul_16mRest_4nRest_k.cpp
85+
matmul_16mRest_lt4nRest_k.h
86+
matmul_16mRest_lt4nRest_k.cpp
8787
matmul_lt16_4n_k.h
8888
matmul_lt16_4n_k.cpp
89+
matmul_lt16_lt4nRest_k.h
90+
matmul_lt16_lt4nRest_k.cpp
8991
)
9092

9193
set(ARM_INSTRUCTION_FILES
@@ -127,9 +129,11 @@ set(TEST_KERNELS
127129
matmul_16_6_1.test.cpp
128130
matmul_16_6_k.test.cpp
129131
matmul_16m_4n_k.test.cpp
132+
matmul_16mRest_4n_k.test.cpp
130133
matmul_16m_lt4nRest_k.test.cpp
131-
matmul_16mRest_4nRest_k.test.cpp
134+
matmul_16mRest_lt4nRest_k.test.cpp
132135
matmul_lt16_4n_k.test.cpp
136+
matmul_lt16_lt4nRest_k.test.cpp
133137
)
134138

135139
set(TEST_ARM_INSTRUCTION_FILES

src/main/Brgemm.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ mini_jit::Brgemm::error_t mini_jit::Brgemm::generate(uint32_t m, uint32_t n, uin
1111
{
1212
return error_t::err_wrong_dtype;
1313
}
14-
if (m < 16 || n < 4 || m == 0 || n == 0 || k == 0)
14+
if (m == 0 || n == 0 || k == 0)
1515
{
1616
return error_t::err_wrong_dimension;
1717
}
@@ -83,10 +83,25 @@ void mini_jit::Brgemm::fill_with_matmuls_no_batch_dim_column_major_fp32(uint32_t
8383
return;
8484
}
8585

86-
if (m >= 16 && n >= 4)
86+
if (m < 16 && n >= 4 && n % 4 == 0)
87+
{
88+
kernels::matmul_lt16_4n_k(native_kernel, n / 4, k, m % 16);
89+
return;
90+
}
91+
92+
if (m >= 16)
93+
{
94+
// At this point m % 16 != 0 and n % 4 != 0
95+
kernels::matmul_16mRest_lt4nRest_k(native_kernel, m / 16, n / 4, k, m % 16, n % 4);
96+
return;
97+
}
98+
99+
if (m < 16)
87100
{
88101
// At this point m % 16 != 0 and n % 4 != 0
89-
kernels::matmul_16mRest_4nRest_k(native_kernel, m / 16, n / 4, k, m % 16, n % 4);
102+
kernels::matmul_lt16_lt4nRest_k(native_kernel, n / 4, k, m % 16, n % 4);
90103
return;
91104
}
105+
106+
throw std::logic_error(std::format("Unhandled combination found for MxNxK matmul: m='{}', n='{}', k='{}'", m, n, k));
92107
}

src/main/kernels/matmul_16mRest_4nRest_k.cpp renamed to src/main/kernels/matmul_16mRest_lt4nRest_k.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
1-
#include "matmul_16mRest_4nRest_k.h"
1+
#include "matmul_16mRest_lt4nRest_k.h"
22
#include "../Kernel.h"
33
#include "../arm_instructions/arm_all.h"
44
#include "../release_assert.h"
55
#include "matmul_16mRest_4n_k.h"
66
#include "matmul_16m_lt4nRest_k.h"
77

8-
void mini_jit::kernels::matmul_16mRest_4nRest_k(mini_jit::Kernel &kernel, const uint32_t m_loop_16, const uint32_t n_loop_4,
9-
const uint32_t k_loop, const uint32_t m_loop_rest, const uint32_t n_loop_rest)
8+
void mini_jit::kernels::matmul_16mRest_lt4nRest_k(mini_jit::Kernel &kernel, const uint32_t m_loop_16, const uint32_t n_loop_4,
9+
const uint32_t k_loop, const uint32_t m_loop_rest, const uint32_t n_loop_rest)
1010
{
1111
using namespace mini_jit::arm_instructions;
1212

1313
release_assert(m_loop_16 != 0, "Cannot proccess matrix with m loop of 0.");
14-
release_assert(n_loop_4 != 0, "Cannot proccess matrix with n loop of 0.");
1514
release_assert(k_loop != 0, "Cannot proccess matrix with k loop of 0.");
1615
release_assert(m_loop_rest != 0, "Cannot create a matrix with a rest of m equal to 0!");
1716
release_assert(m_loop_rest <= 15, "Cannot create a matrix with a rest of m larger than 15!");
1817
release_assert(n_loop_rest != 0, "Cannot create a matrix with a rest of n equal to 0!");
1918
release_assert(n_loop_rest <= 3, "Cannot create a matrix with a rest of n larger than 3!");
2019

2120
// Idea: Division of the matrix into sub-matrices and calculated in the following order.
21+
// 1. matmul_lt16_4n_k is omitted if n is less than 4;
22+
//
2223
// N dimension
2324
// ←---------------------------------------------------→
2425
// ===================================================== ↑
@@ -79,7 +80,10 @@ void mini_jit::kernels::matmul_16mRest_4nRest_k(mini_jit::Kernel &kernel, const
7980
// ========================================================================================
8081
// Calculate m + rest but n is multiple of 4
8182
// ========================================================================================
82-
matmul_16mRest_4n_k(kernel, m_loop_16, n_loop_4, k_loop, m_loop_rest, false);
83+
if (n_loop_4 != 0)
84+
{
85+
matmul_16mRest_4n_k(kernel, m_loop_16, n_loop_4, k_loop, m_loop_rest, false);
86+
}
8387

8488
// Offset to the next matrix block
8589
// Here we want to start with the initial m value but n should be offset by the already calculated amount.
@@ -91,7 +95,7 @@ void mini_jit::kernels::matmul_16mRest_4nRest_k(mini_jit::Kernel &kernel, const
9195
matmul_16m_lt4nRest_k(kernel, m_loop_16, 0, k_loop, n_loop_rest, false);
9296

9397
// Now we want to make sure to not restore the position of the m as it is in the right position.
94-
// Therefore we should restore the register above the m_loop
98+
// Therefore we should restore the register below the m_loop
9599

96100
// ========================================================================================
97101
// Rest Calculation of m and n loop

src/main/kernels/matmul_16mRest_4nRest_k.h renamed to src/main/kernels/matmul_16mRest_lt4nRest_k.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef MINI_JIT_KERNELS_MATMUL_16MRest_4NRest_K_H
2-
#define MINI_JIT_KERNELS_MATMUL_16MRest_4NRest_K_H
1+
#ifndef MINI_JIT_KERNELS_MATMUL_16MRest_LT4NRest_K_H
2+
#define MINI_JIT_KERNELS_MATMUL_16MRest_LT4NRest_K_H
33

44
#include "../Kernel.h"
55
#include <cstdint>
@@ -19,9 +19,9 @@ namespace mini_jit
1919
* @param m_loop_rest The rest/remainder of the m loop that is not dividable by 16.
2020
* @param n_loop_rest The rest/remainder of the n loop that is not dividable by 4.
2121
*/
22-
void matmul_16mRest_4nRest_k(mini_jit::Kernel &kernel, const uint32_t m_loop_16, const uint32_t n_loop_4, const uint32_t k_loop,
23-
const uint32_t m_loop_rest, const uint32_t n_loop_rest);
22+
void matmul_16mRest_lt4nRest_k(mini_jit::Kernel &kernel, const uint32_t m_loop_16, const uint32_t n_loop_4, const uint32_t k_loop,
23+
const uint32_t m_loop_rest, const uint32_t n_loop_rest);
2424

2525
} // namespace kernels
2626
} // namespace mini_jit
27-
#endif // MINI_JIT_KERNELS_MATMUL_16MRest_4NRest_K_H
27+
#endif // MINI_JIT_KERNELS_MATMUL_16MRest_LT4NRest_K_H

src/main/kernels/matmul_lt16_4n_k.cpp

Lines changed: 76 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,68 +4,72 @@
44
#include "../release_assert.h"
55

66
void mini_jit::kernels::matmul_lt16_4n_k(mini_jit::Kernel &kernel, const uint32_t n_loop_4, const uint32_t k_loop,
7-
const uint32_t m_loop_rest)
7+
const uint32_t m_loop_rest, const bool use_init_and_end)
88
{
99
using namespace mini_jit::arm_instructions;
1010

1111
release_assert(n_loop_4 != 0, "Cannot proccess matrix with k loop of 0.");
1212
release_assert(k_loop != 0, "Cannot proccess matrix with k loop of 0.");
1313
release_assert(m_loop_rest != 0, "Cannot create a matrix with a rest of m equal to 0!");
1414
release_assert(m_loop_rest <= 15, "Cannot create a matrix with a rest of m larger than 15!");
15-
// Hold the number of instruction to jump for each loop
16-
int32_t jump_N_loop = 23; // start value = amount of instructions outside of control flow
1715

18-
kernel.add({
19-
// /**
20-
// * @param x0 = a pointer to column-major 64x64 matrix A.
21-
// * @param x1 = b pointer to column-major 64x64 matrix B.
22-
// * @param x2 = c pointer to column-major 64x64 matrix C.
23-
// * @param x3 = lda leading dimension of A.
24-
// * @param x4 = ldb leading dimension of B.
25-
// * @param x5 = ldc leading dimension of C.
26-
// **/
27-
// .type matmul_64_48_64, %function
28-
// .global matmul_64_48_64
29-
// matmul_64_48_64:
30-
31-
// // Procedural Call Standard
32-
// // save frame pointer and link register
33-
// // stp fp, lr, [sp, #-16]!
34-
// // update frame pointer to current stack pointer
35-
// // mov fp, sp
36-
37-
// // save callee-saved registers
38-
// // stp x19, x20, [sp, #-16]!
39-
// // stp x21, x22, [sp, #-16]!
40-
// // stp x23, x24, [sp, #-16]!
41-
// // stp x25, x26, [sp, #-16]!
42-
// // stp x27, x28, [sp, #-16]!
43-
44-
stpPre(d8, d9, sp, -16), // stp d8, d9, [sp, #-16]!
45-
// // stp d10, d11, [sp, #-16]!
46-
// // stp d12, d13, [sp, #-16]!
47-
// // stp d14, d15, [sp, #-16]!
48-
49-
// // Offset the used leading dimension by the size of floats
50-
lsl(x3, x3, 2), // lsl x3, x3, #2 // x3 * 4 = x3 * sizeof(float)
51-
lsl(x4, x4, 2), // lsl x4, x4, #2 // x4 * 4 = x4 * sizeof(float)
52-
lsl(x5, x5, 2), // lsl x5, x5, #2 // x5 * 4 = x5 * sizeof(float)
53-
54-
mov(x6, x1), // mov x6, x1 // Store the initial value of x1, to be restored in the K loop iteration
55-
mov(x7, x2), // mov x7, x2 // Store the initial value of x2, to be restored in the K loop iteration
56-
57-
mov(x8, x0), // mov x8, x0 // Store the initial value of x0, to be restored in the M loop iteration
58-
mov(x9, x1), // mov x9, x1 // Store the initial value of x1, to be restored in the M loop iteration
59-
60-
mov(x10, x0), // mov x10, x0 // Store the initial value of x0, to be restored in the N loop iteration
61-
mov(x11, x2), // mov x11, x2 // Store the initial value of x2, to bes restored in the N loop iteration
62-
mov(x12, 4), // mov x12, #4 // hold the size of N that are processed in one loop, needed for offset calculation
63-
});
16+
if (use_init_and_end)
17+
{
18+
kernel.add({
19+
// /**
20+
// * @param x0 = a pointer to column-major 64x64 matrix A.
21+
// * @param x1 = b pointer to column-major 64x64 matrix B.
22+
// * @param x2 = c pointer to column-major 64x64 matrix C.
23+
// * @param x3 = lda leading dimension of A.
24+
// * @param x4 = ldb leading dimension of B.
25+
// * @param x5 = ldc leading dimension of C.
26+
// **/
27+
// .type matmul_64_48_64, %function
28+
// .global matmul_64_48_64
29+
// matmul_64_48_64:
30+
31+
// // Procedural Call Standard
32+
// // save frame pointer and link register
33+
// // stp fp, lr, [sp, #-16]!
34+
// // update frame pointer to current stack pointer
35+
// // mov fp, sp
36+
37+
// // save callee-saved registers
38+
// // stp x19, x20, [sp, #-16]!
39+
// // stp x21, x22, [sp, #-16]!
40+
// // stp x23, x24, [sp, #-16]!
41+
// // stp x25, x26, [sp, #-16]!
42+
// // stp x27, x28, [sp, #-16]!
43+
44+
stpPre(d8, d9, sp, -16), // stp d8, d9, [sp, #-16]!
45+
// // stp d10, d11, [sp, #-16]!
46+
// // stp d12, d13, [sp, #-16]!
47+
// // stp d14, d15, [sp, #-16]!
48+
49+
// // Offset the used leading dimension by the size of floats
50+
lsl(x3, x3, 2), // lsl x3, x3, #2 // x3 * 4 = x3 * sizeof(float)
51+
lsl(x4, x4, 2), // lsl x4, x4, #2 // x4 * 4 = x4 * sizeof(float)
52+
lsl(x5, x5, 2), // lsl x5, x5, #2 // x5 * 4 = x5 * sizeof(float)
53+
54+
mov(x6, x1), // mov x6, x1 // Store the initial value of x1, to be restored in the K loop iteration
55+
mov(x7, x2), // mov x7, x2 // Store the initial value of x2, to be restored in the K loop iteration
56+
57+
mov(x8, x0), // mov x8, x0 // Store the initial value of x0, to be restored in the M loop iteration
58+
mov(x9, x1), // mov x9, x1 // Store the initial value of x1, to be restored in the M loop iteration
59+
60+
mov(x10, x0), // mov x10, x0 // Store the initial value of x0, to be restored in the N loop iteration
61+
mov(x11, x2), // mov x11, x2 // Store the initial value of x2, to bes restored in the N loop iteration
62+
mov(x12, 4), // mov x12, #4 // hold the size of N that are processed in one loop, needed for offset calculation
63+
});
64+
}
6465

6566
// ========================================================================================
6667
// Rest Calculation of m loop
6768
// ========================================================================================
6869

70+
// Hold the number of instruction to jump for each loop
71+
int32_t jump_N_loop = 23; // start value = amount of instructions outside of control flow
72+
6973
kernel.add({
7074
mov(x17, n_loop_4), // mov x17, #12 // x17 iterator for N loop
7175
// matmul_loop_over_N:
@@ -733,26 +737,33 @@ void mini_jit::kernels::matmul_lt16_4n_k(mini_jit::Kernel &kernel, const uint32_
733737

734738
// // Loop back to N
735739
cbnz(x17, -jump_N_loop * 4), // cbnz x17, matmul_loop_over_N
740+
});
736741

737-
// // Procedural Call Standard
738-
// // restore callee-saved registers
739-
// // ldp d14, d15, [sp], #16
740-
// // ldp d12, d13, [sp], #16
741-
// // ldp d10, d11, [sp], #16
742-
ldpPost(d8, d9, sp, 16), // ldp d8, d9, [sp], #16
742+
if (use_init_and_end)
743+
{
743744

744-
// // ldp x27, x28, [sp], #16
745-
// // ldp x25, x26, [sp], #16
746-
// // ldp x23, x24, [sp], #16
747-
// // ldp x21, x22, [sp], #16
748-
// // ldp x19, x20, [sp], #16
745+
kernel.add({
749746

750-
// // restore frame pointer and link register
751-
// // ldp fp, lr, [sp], #16
747+
// // Procedural Call Standard
748+
// // restore callee-saved registers
749+
// // ldp d14, d15, [sp], #16
750+
// // ldp d12, d13, [sp], #16
751+
// // ldp d10, d11, [sp], #16
752+
ldpPost(d8, d9, sp, 16), // ldp d8, d9, [sp], #16
752753

753-
ret() // ret
754-
// .size matmul_64_48_64, (. - matmul_64_48_64)
755-
});
754+
// // ldp x27, x28, [sp], #16
755+
// // ldp x25, x26, [sp], #16
756+
// // ldp x23, x24, [sp], #16
757+
// // ldp x21, x22, [sp], #16
758+
// // ldp x19, x20, [sp], #16
759+
760+
// // restore frame pointer and link register
761+
// // ldp fp, lr, [sp], #16
762+
763+
ret() // ret
764+
// .size matmul_64_48_64, (. - matmul_64_48_64)
765+
});
766+
}
756767

757768
#ifdef SAVE_JITS_TO_FILE
758769
kernel.write("matmul_lt16_4n_k.bin");

src/main/kernels/matmul_lt16_4n_k.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ namespace mini_jit
1616
* @param n_loop_4 The repetitions of the n block of size 4.
1717
* @param k_loop The loops in the k dimensions.
1818
* @param m_loop_rest The rest/remainder of the m loop that is not dividable by 16.
19+
* @param use_init_and_end Indicates if the procedural call standard, initializing setup and the ret instruction are used. Defaults to
20+
* true.
1921
*/
20-
void matmul_lt16_4n_k(mini_jit::Kernel &kernel, const uint32_t n_loop_4, const uint32_t k_loop, const uint32_t m_loop_rest);
22+
void matmul_lt16_4n_k(mini_jit::Kernel &kernel, const uint32_t n_loop_4, const uint32_t k_loop, const uint32_t m_loop_rest,
23+
const bool use_init_and_end = true);
2124

2225
} // namespace kernels
2326
} // namespace mini_jit

0 commit comments

Comments
 (0)