|
4 | 4 | #include "../release_assert.h" |
5 | 5 |
|
6 | 6 | void mini_jit::kernels::matmul_lt16_4n_k(mini_jit::Kernel &kernel, const uint32_t n_loop_4, const uint32_t k_loop, |
7 | | - const uint32_t m_loop_rest) |
| 7 | + const uint32_t m_loop_rest, const bool use_init_and_end) |
8 | 8 | { |
9 | 9 | using namespace mini_jit::arm_instructions; |
10 | 10 |
|
11 | 11 | release_assert(n_loop_4 != 0, "Cannot proccess matrix with k loop of 0."); |
12 | 12 | release_assert(k_loop != 0, "Cannot proccess matrix with k loop of 0."); |
13 | 13 | release_assert(m_loop_rest != 0, "Cannot create a matrix with a rest of m equal to 0!"); |
14 | 14 | release_assert(m_loop_rest <= 15, "Cannot create a matrix with a rest of m larger than 15!"); |
15 | | - // Hold the number of instruction to jump for each loop |
16 | | - int32_t jump_N_loop = 23; // start value = amount of instructions outside of control flow |
17 | 15 |
|
18 | | - kernel.add({ |
19 | | - // /** |
20 | | - // * @param x0 = a pointer to column-major 64x64 matrix A. |
21 | | - // * @param x1 = b pointer to column-major 64x64 matrix B. |
22 | | - // * @param x2 = c pointer to column-major 64x64 matrix C. |
23 | | - // * @param x3 = lda leading dimension of A. |
24 | | - // * @param x4 = ldb leading dimension of B. |
25 | | - // * @param x5 = ldc leading dimension of C. |
26 | | - // **/ |
27 | | - // .type matmul_64_48_64, %function |
28 | | - // .global matmul_64_48_64 |
29 | | - // matmul_64_48_64: |
30 | | - |
31 | | - // // Procedural Call Standard |
32 | | - // // save frame pointer and link register |
33 | | - // // stp fp, lr, [sp, #-16]! |
34 | | - // // update frame pointer to current stack pointer |
35 | | - // // mov fp, sp |
36 | | - |
37 | | - // // save callee-saved registers |
38 | | - // // stp x19, x20, [sp, #-16]! |
39 | | - // // stp x21, x22, [sp, #-16]! |
40 | | - // // stp x23, x24, [sp, #-16]! |
41 | | - // // stp x25, x26, [sp, #-16]! |
42 | | - // // stp x27, x28, [sp, #-16]! |
43 | | - |
44 | | - stpPre(d8, d9, sp, -16), // stp d8, d9, [sp, #-16]! |
45 | | - // // stp d10, d11, [sp, #-16]! |
46 | | - // // stp d12, d13, [sp, #-16]! |
47 | | - // // stp d14, d15, [sp, #-16]! |
48 | | - |
49 | | - // // Offset the used leading dimension by the size of floats |
50 | | - lsl(x3, x3, 2), // lsl x3, x3, #2 // x3 * 4 = x3 * sizeof(float) |
51 | | - lsl(x4, x4, 2), // lsl x4, x4, #2 // x4 * 4 = x4 * sizeof(float) |
52 | | - lsl(x5, x5, 2), // lsl x5, x5, #2 // x5 * 4 = x5 * sizeof(float) |
53 | | - |
54 | | - mov(x6, x1), // mov x6, x1 // Store the initial value of x1, to be restored in the K loop iteration |
55 | | - mov(x7, x2), // mov x7, x2 // Store the initial value of x2, to be restored in the K loop iteration |
56 | | - |
57 | | - mov(x8, x0), // mov x8, x0 // Store the initial value of x0, to be restored in the M loop iteration |
58 | | - mov(x9, x1), // mov x9, x1 // Store the initial value of x1, to be restored in the M loop iteration |
59 | | - |
60 | | - mov(x10, x0), // mov x10, x0 // Store the initial value of x0, to be restored in the N loop iteration |
61 | | - mov(x11, x2), // mov x11, x2 // Store the initial value of x2, to bes restored in the N loop iteration |
62 | | - mov(x12, 4), // mov x12, #4 // hold the size of N that are processed in one loop, needed for offset calculation |
63 | | - }); |
| 16 | + if (use_init_and_end) |
| 17 | + { |
| 18 | + kernel.add({ |
| 19 | + // /** |
| 20 | + // * @param x0 = a pointer to column-major 64x64 matrix A. |
| 21 | + // * @param x1 = b pointer to column-major 64x64 matrix B. |
| 22 | + // * @param x2 = c pointer to column-major 64x64 matrix C. |
| 23 | + // * @param x3 = lda leading dimension of A. |
| 24 | + // * @param x4 = ldb leading dimension of B. |
| 25 | + // * @param x5 = ldc leading dimension of C. |
| 26 | + // **/ |
| 27 | + // .type matmul_64_48_64, %function |
| 28 | + // .global matmul_64_48_64 |
| 29 | + // matmul_64_48_64: |
| 30 | + |
| 31 | + // // Procedural Call Standard |
| 32 | + // // save frame pointer and link register |
| 33 | + // // stp fp, lr, [sp, #-16]! |
| 34 | + // // update frame pointer to current stack pointer |
| 35 | + // // mov fp, sp |
| 36 | + |
| 37 | + // // save callee-saved registers |
| 38 | + // // stp x19, x20, [sp, #-16]! |
| 39 | + // // stp x21, x22, [sp, #-16]! |
| 40 | + // // stp x23, x24, [sp, #-16]! |
| 41 | + // // stp x25, x26, [sp, #-16]! |
| 42 | + // // stp x27, x28, [sp, #-16]! |
| 43 | + |
| 44 | + stpPre(d8, d9, sp, -16), // stp d8, d9, [sp, #-16]! |
| 45 | + // // stp d10, d11, [sp, #-16]! |
| 46 | + // // stp d12, d13, [sp, #-16]! |
| 47 | + // // stp d14, d15, [sp, #-16]! |
| 48 | + |
| 49 | + // // Offset the used leading dimension by the size of floats |
| 50 | + lsl(x3, x3, 2), // lsl x3, x3, #2 // x3 * 4 = x3 * sizeof(float) |
| 51 | + lsl(x4, x4, 2), // lsl x4, x4, #2 // x4 * 4 = x4 * sizeof(float) |
| 52 | + lsl(x5, x5, 2), // lsl x5, x5, #2 // x5 * 4 = x5 * sizeof(float) |
| 53 | + |
| 54 | + mov(x6, x1), // mov x6, x1 // Store the initial value of x1, to be restored in the K loop iteration |
| 55 | + mov(x7, x2), // mov x7, x2 // Store the initial value of x2, to be restored in the K loop iteration |
| 56 | + |
| 57 | + mov(x8, x0), // mov x8, x0 // Store the initial value of x0, to be restored in the M loop iteration |
| 58 | + mov(x9, x1), // mov x9, x1 // Store the initial value of x1, to be restored in the M loop iteration |
| 59 | + |
| 60 | + mov(x10, x0), // mov x10, x0 // Store the initial value of x0, to be restored in the N loop iteration |
| 61 | + mov(x11, x2), // mov x11, x2 // Store the initial value of x2, to be restored in the N loop iteration |
| 62 | + mov(x12, 4), // mov x12, #4 // hold the size of N that are processed in one loop, needed for offset calculation |
| 63 | + }); |
| 64 | + } |
64 | 65 |
|
65 | 66 | // ======================================================================================== |
66 | 67 | // Rest Calculation of m loop |
67 | 68 | // ======================================================================================== |
68 | 69 |
|
| 70 | + // Hold the number of instructions to jump for each loop |
| 71 | + int32_t jump_N_loop = 23; // start value = amount of instructions outside of control flow |
| 72 | + |
69 | 73 | kernel.add({ |
70 | 74 | mov(x17, n_loop_4), // mov x17, #12 // x17 iterator for N loop |
71 | 75 | // matmul_loop_over_N: |
@@ -733,26 +737,33 @@ void mini_jit::kernels::matmul_lt16_4n_k(mini_jit::Kernel &kernel, const uint32_ |
733 | 737 |
|
734 | 738 | // // Loop back to N |
735 | 739 | cbnz(x17, -jump_N_loop * 4), // cbnz x17, matmul_loop_over_N |
| 740 | + }); |
736 | 741 |
|
737 | | - // // Procedural Call Standard |
738 | | - // // restore callee-saved registers |
739 | | - // // ldp d14, d15, [sp], #16 |
740 | | - // // ldp d12, d13, [sp], #16 |
741 | | - // // ldp d10, d11, [sp], #16 |
742 | | - ldpPost(d8, d9, sp, 16), // ldp d8, d9, [sp], #16 |
| 742 | + if (use_init_and_end) |
| 743 | + { |
743 | 744 |
|
744 | | - // // ldp x27, x28, [sp], #16 |
745 | | - // // ldp x25, x26, [sp], #16 |
746 | | - // // ldp x23, x24, [sp], #16 |
747 | | - // // ldp x21, x22, [sp], #16 |
748 | | - // // ldp x19, x20, [sp], #16 |
| 745 | + kernel.add({ |
749 | 746 |
|
750 | | - // // restore frame pointer and link register |
751 | | - // // ldp fp, lr, [sp], #16 |
| 747 | + // // Procedural Call Standard |
| 748 | + // // restore callee-saved registers |
| 749 | + // // ldp d14, d15, [sp], #16 |
| 750 | + // // ldp d12, d13, [sp], #16 |
| 751 | + // // ldp d10, d11, [sp], #16 |
| 752 | + ldpPost(d8, d9, sp, 16), // ldp d8, d9, [sp], #16 |
752 | 753 |
|
753 | | - ret() // ret |
754 | | - // .size matmul_64_48_64, (. - matmul_64_48_64) |
755 | | - }); |
| 754 | + // // ldp x27, x28, [sp], #16 |
| 755 | + // // ldp x25, x26, [sp], #16 |
| 756 | + // // ldp x23, x24, [sp], #16 |
| 757 | + // // ldp x21, x22, [sp], #16 |
| 758 | + // // ldp x19, x20, [sp], #16 |
| 759 | + |
| 760 | + // // restore frame pointer and link register |
| 761 | + // // ldp fp, lr, [sp], #16 |
| 762 | + |
| 763 | + ret() // ret |
| 764 | + // .size matmul_64_48_64, (. - matmul_64_48_64) |
| 765 | + }); |
| 766 | + } |
756 | 767 |
|
757 | 768 | #ifdef SAVE_JITS_TO_FILE |
758 | 769 | kernel.write("matmul_lt16_4n_k.bin"); |
|
0 commit comments