Skip to content

Commit 2c6dd4c

Browse files
committed
style: format all .ccp and .h
1 parent 456cf4f commit 2c6dd4c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+9128
-9270
lines changed

.clang-format

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
BasedOnStyle: LLVM
22
IndentWidth: 2
33
ContinuationIndentWidth: 2
4+
SpacesBeforeTrailingComments: 2
45
UseTab: Never
56
BreakBeforeBraces: Allman
67
AllowShortIfStatementsOnASingleLine: false
78
AllowShortLoopsOnASingleLine: false
89
AllowShortFunctionsOnASingleLine: None
9-
SpacesBeforeParens: ControlStatements
10+
SpaceBeforeParens: ControlStatements
1011
AlwaysBreakAfterReturnType: None
1112
BinPackArguments: true
1213
BinPackParameters: true

src/main/Brgemm.cpp

Lines changed: 63 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,85 @@
11
#include "Brgemm.h"
2-
#include "kernels/matmuls_all.h"
32
#include "Kernel.h"
4-
#include <stdexcept>
3+
#include "kernels/matmuls_all.h"
54
#include <format>
5+
#include <stdexcept>
66

7-
mini_jit::Brgemm::error_t mini_jit::Brgemm::generate(uint32_t m, uint32_t n, uint32_t k, uint32_t br_size,
8-
uint32_t trans_a, uint32_t trans_b, uint32_t trans_c, dtype_t dtype)
7+
mini_jit::Brgemm::error_t mini_jit::Brgemm::generate(uint32_t m, uint32_t n, uint32_t k, uint32_t br_size, uint32_t trans_a,
8+
uint32_t trans_b, uint32_t trans_c, dtype_t dtype)
99
{
10-
if (dtype != dtype_t::fp32)
11-
{
12-
return error_t::err_wrong_dtype;
13-
}
14-
if (m % 16 != 0 || (n < 4))
15-
{
16-
return error_t::err_wrong_dimension;
17-
}
18-
if ((trans_a + trans_b + trans_c) != 0)
19-
{
20-
return error_t::err_row_major_order_not_supported;
21-
}
22-
if (br_size != 1)
23-
{
24-
return error_t::err_batch_reduce_size_not_supported;
25-
}
10+
if (dtype != dtype_t::fp32)
11+
{
12+
return error_t::err_wrong_dtype;
13+
}
14+
if (m % 16 != 0 || (n < 4))
15+
{
16+
return error_t::err_wrong_dimension;
17+
}
18+
if ((trans_a + trans_b + trans_c) != 0)
19+
{
20+
return error_t::err_row_major_order_not_supported;
21+
}
22+
if (br_size != 1)
23+
{
24+
return error_t::err_batch_reduce_size_not_supported;
25+
}
2626

27-
if (br_size == 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
28-
{
29-
fill_with_matmuls_no_batch_dim_column_major_fp32(m, n, k);
30-
}
31-
else
32-
{
33-
throw std::logic_error(std::format(
34-
"Unhandled parameter combination found: m='{}', n='{}', k='{}', br_size='{}', trans_a='{}', trans_b='{}', "
35-
"trans_c = '{}', dtype = '{}'", m, n, k, br_size, trans_a, trans_b, trans_c, static_cast<int32_t>(dtype)));
36-
}
27+
if (br_size == 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
28+
{
29+
fill_with_matmuls_no_batch_dim_column_major_fp32(m, n, k);
30+
}
31+
else
32+
{
33+
throw std::logic_error(
34+
std::format("Unhandled parameter combination found: m='{}', n='{}', k='{}', br_size='{}', trans_a='{}', trans_b='{}', "
35+
"trans_c = '{}', dtype = '{}'",
36+
m, n, k, br_size, trans_a, trans_b, trans_c, static_cast<int32_t>(dtype)));
37+
}
3738

39+
native_kernel.set_kernel();
40+
kernel = reinterpret_cast<kernel_t>(const_cast<void *>(native_kernel.get_kernel())); // Properly cast from const void* to kernel_t
3841

39-
40-
native_kernel.set_kernel();
41-
kernel = reinterpret_cast<kernel_t>(const_cast<void*>(native_kernel.get_kernel())); // Properly cast from const void* to kernel_t
42-
43-
return error_t::success;
42+
return error_t::success;
4443
}
4544

4645
mini_jit::Brgemm::kernel_t mini_jit::Brgemm::get_kernel() const
4746
{
48-
return kernel;
47+
return kernel;
4948
}
5049

5150
void mini_jit::Brgemm::fill_with_matmuls_no_batch_dim_column_major_fp32(uint32_t m, uint32_t n, uint32_t k)
5251
{
53-
// Always sort from the specific to the more general case
52+
// Always sort from the specific to the more general case
5453

55-
if (m == 16 && n == 6 && k == 1)
56-
{
57-
kernels::matmul_16_6_1(native_kernel);
58-
return;
59-
}
54+
if (m == 16 && n == 6 && k == 1)
55+
{
56+
kernels::matmul_16_6_1(native_kernel);
57+
return;
58+
}
6059

61-
if (m == 16 && n == 6)
62-
{
63-
kernels::matmul_16_6_k(native_kernel, k);
64-
return;
65-
}
60+
if (m == 16 && n == 6)
61+
{
62+
kernels::matmul_16_6_k(native_kernel, k);
63+
return;
64+
}
6665

67-
if (m >= 16 && m % 16 == 0 && n >= 4 && n % 4 == 0)
68-
{
69-
kernels::matmul_16m_4n_k(native_kernel, m / 16, n / 4, k);
70-
return;
71-
}
66+
if (m >= 16 && m % 16 == 0 && n >= 4 && n % 4 == 0)
67+
{
68+
kernels::matmul_16m_4n_k(native_kernel, m / 16, n / 4, k);
69+
return;
70+
}
7271

73-
if (m >= 16 && m % 16 == 0 && n >= 4)
74-
{
75-
// At this point n % 4 != 0
76-
kernels::matmul_16m_4nRest_k(native_kernel, m / 16, n / 4, k, n % 4);
77-
return;
78-
}
72+
if (m >= 16 && m % 16 == 0 && n >= 4)
73+
{
74+
// At this point n % 4 != 0
75+
kernels::matmul_16m_4nRest_k(native_kernel, m / 16, n / 4, k, n % 4);
76+
return;
77+
}
7978

80-
if (m >= 16 && n >= 4 && n % 4 == 0)
81-
{
82-
// At this point m % 16 != 0
83-
kernels::matmul_16mRest_4n_k(native_kernel, m / 16, n / 4, k, m % 16);
84-
return;
85-
}
79+
if (m >= 16 && n >= 4 && n % 4 == 0)
80+
{
81+
// At this point m % 16 != 0
82+
kernels::matmul_16mRest_4n_k(native_kernel, m / 16, n / 4, k, m % 16);
83+
return;
84+
}
8685
}
87-

src/main/Brgemm.h

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#ifndef MINI_JIT_BRGEMM_H
22
#define MINI_JIT_BRGEMM_H
33

4-
#include <cstdint>
54
#include "Kernel.h"
5+
#include <cstdint>
66

77
namespace mini_jit
88
{
9-
class Brgemm;
9+
class Brgemm;
1010
}
1111

1212
class mini_jit::Brgemm
@@ -25,14 +25,8 @@ class mini_jit::Brgemm
2525
* - br_stride_a: stride between two A matrices (in elements, not bytes).
2626
* - br_stride_b: stride between two B matrices (in elements, not bytes).
2727
*/
28-
using kernel_t = void (*)(void const* a,
29-
void const* b,
30-
void* c,
31-
int64_t lda,
32-
int64_t ldb,
33-
int64_t ldc,
34-
int64_t br_stride_a,
35-
int64_t br_stride_b);
28+
using kernel_t = void (*)(void const *a, void const *b, void *c, int64_t lda, int64_t ldb, int64_t ldc, int64_t br_stride_a,
29+
int64_t br_stride_b);
3630

3731
/// data type
3832
enum class dtype_t : uint32_t
@@ -63,14 +57,8 @@ class mini_jit::Brgemm
6357
* @param dtype data type of the matrices.
6458
* @return error_t::success on success, another error_t value otherwise.
6559
**/
66-
error_t generate(uint32_t m,
67-
uint32_t n,
68-
uint32_t k,
69-
uint32_t br_size,
70-
uint32_t trans_a,
71-
uint32_t trans_b,
72-
uint32_t trans_c,
73-
dtype_t dtype);
60+
error_t generate(uint32_t m, uint32_t n, uint32_t k, uint32_t br_size, uint32_t trans_a, uint32_t trans_b, uint32_t trans_c,
61+
dtype_t dtype);
7462

7563
/**
7664
* @brief Get the generated kernel: C += sum_i(A_i * B_i).

0 commit comments

Comments
 (0)