|
1 | 1 | #include "Brgemm.h" |
2 | | -#include "kernels/matmuls_all.h" |
3 | 2 | #include "Kernel.h" |
4 | | -#include <stdexcept> |
| 3 | +#include "kernels/matmuls_all.h" |
5 | 4 | #include <format> |
| 5 | +#include <stdexcept> |
6 | 6 |
|
/**
 * @brief Validates the requested configuration, emits the matching matmul kernel and stores a callable pointer to it.
 *
 * Only the following configuration is currently accepted:
 *   - dtype == dtype_t::fp32
 *   - m is a non-zero multiple of 16 and n >= 4
 *   - trans_a == trans_b == trans_c == 0
 *   - br_size == 1
 *
 * @param m Number of rows of the result; must be a non-zero multiple of 16.
 * @param n Number of columns of the result; must be at least 4.
 * @param k Contraction dimension; forwarded unchanged to the kernel generator.
 * @param br_size Batch-reduce size; only 1 is accepted.
 * @param trans_a Layout flag for A; only 0 is accepted.
 * @param trans_b Layout flag for B; only 0 is accepted.
 * @param trans_c Layout flag for C; only 0 is accepted.
 * @param dtype Element type; only dtype_t::fp32 is accepted.
 * @return error_t::success when a kernel was generated, otherwise the error code of the first failed check.
 * @throws std::logic_error if a configuration passes validation but has no generator (defensive; unreachable with
 *         the checks below).
 */
mini_jit::Brgemm::error_t mini_jit::Brgemm::generate(uint32_t m, uint32_t n, uint32_t k, uint32_t br_size, uint32_t trans_a,
                                                     uint32_t trans_b, uint32_t trans_c, dtype_t dtype)
{
  if (dtype != dtype_t::fp32)
  {
    return error_t::err_wrong_dtype;
  }
  // m == 0 satisfies "m % 16 == 0" but every kernel variant selected in
  // fill_with_matmuls_no_batch_dim_column_major_fp32 requires m >= 16, so it has to be rejected explicitly;
  // otherwise an empty kernel would be reported as success.
  if (m == 0 || m % 16 != 0 || n < 4)
  {
    return error_t::err_wrong_dimension;
  }
  if ((trans_a + trans_b + trans_c) != 0)
  {
    return error_t::err_row_major_order_not_supported;
  }
  if (br_size != 1)
  {
    return error_t::err_batch_reduce_size_not_supported;
  }

  if (br_size == 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
  {
    fill_with_matmuls_no_batch_dim_column_major_fp32(m, n, k);
  }
  else
  {
    // Safety net: only reachable if the validation above is relaxed without adding a matching generator.
    throw std::logic_error(
      std::format("Unhandled parameter combination found: m='{}', n='{}', k='{}', br_size='{}', trans_a='{}', trans_b='{}', "
                  "trans_c = '{}', dtype = '{}'",
                  m, n, k, br_size, trans_a, trans_b, trans_c, static_cast<int32_t>(dtype)));
  }

  // NOTE(review): set_kernel() is invoked before the pointer is read; presumably it finalizes the emitted code
  // (see Kernel.h) - confirm.
  native_kernel.set_kernel();
  // get_kernel() hands out a const void*; drop the const and reinterpret it as the callable kernel type.
  kernel = reinterpret_cast<kernel_t>(const_cast<void *>(native_kernel.get_kernel()));

  return error_t::success;
}
45 | 44 |
|
46 | 45 | mini_jit::Brgemm::kernel_t mini_jit::Brgemm::get_kernel() const |
47 | 46 | { |
48 | | - return kernel; |
| 47 | + return kernel; |
49 | 48 | } |
50 | 49 |
|
/**
 * @brief Selects the matmul generator matching (m, n, k) and lets it emit its code into native_kernel.
 *
 * The candidates are tested from the most specific to the most general shape; the first match wins.
 *
 * @param m Number of rows of the result.
 * @param n Number of columns of the result.
 * @param k Contraction dimension, forwarded to the selected generator.
 * @throws std::logic_error if no generator covers the requested shape (m < 16, n < 4, or both m % 16 != 0 and
 *         n % 4 != 0), instead of returning silently with native_kernel left untouched.
 */
void mini_jit::Brgemm::fill_with_matmuls_no_batch_dim_column_major_fp32(uint32_t m, uint32_t n, uint32_t k)
{
  // Always sort from the specific to the more general case

  // Dedicated 16x6 kernels; k == 1 has its own variant.
  if (m == 16 && n == 6)
  {
    if (k == 1)
    {
      kernels::matmul_16_6_1(native_kernel);
    }
    else
    {
      kernels::matmul_16_6_k(native_kernel, k);
    }
    return;
  }

  // All remaining generators work on blocks of 16 rows by 4 columns and need at least one full block.
  if (m >= 16 && n >= 4)
  {
    const uint32_t m_blocks = m / 16;
    const uint32_t n_blocks = n / 4;
    const uint32_t m_rest = m % 16;
    const uint32_t n_rest = n % 4;

    if (m_rest == 0 && n_rest == 0)
    {
      kernels::matmul_16m_4n_k(native_kernel, m_blocks, n_blocks, k);
      return;
    }

    if (m_rest == 0)
    {
      // At this point n % 4 != 0
      kernels::matmul_16m_4nRest_k(native_kernel, m_blocks, n_blocks, k, n_rest);
      return;
    }

    if (n_rest == 0)
    {
      // At this point m % 16 != 0
      kernels::matmul_16mRest_4n_k(native_kernel, m_blocks, n_blocks, k, m_rest);
      return;
    }
  }

  // No generator matched: fail loudly rather than leaving native_kernel without any emitted code.
  throw std::logic_error(
    std::format("No matmul kernel available for column-major fp32 without batch dimension: m='{}', n='{}', k='{}'", m, n, k));
}
87 | | - |
0 commit comments