@@ -48,6 +48,8 @@ __attribute__((weak)) void print_verbose_header() {}
4848} // namespace dnnl
4949
5050static constexpr int PALETTE_SIZE = 64 ;
51+ static constexpr int DEFAULT_KERNEL_SIZE = 1024 ;
52+ static constexpr int MAX_KERNEL_SIZE = 2048 ;
5153
5254using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
5355using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
@@ -56,81 +58,78 @@ static std::shared_mutex g_brgemm_lock;
5658struct brgemm_cache_info_t {
5759 brgemm_desc_t desc;
5860 brgemm_kernel_t *kernel;
59- std::shared_ptr <char []> palette;
61+ std::unique_ptr <char []> palette;
6062};
6163
62- static std::vector<brgemm_cache_info_t > g_cache;
64+ static std::vector<brgemm_cache_info_t > g_cache (DEFAULT_KERNEL_SIZE);
65+ static int64_t g_kernel_id = -1 ;
6366
6467// TODO(haixin): use syscall to determine page size?
6568static constexpr size_t SCRATCH_SIZE = 2 * 4096 ;
6669// TODO(haixin): need to use custom thread management for scratch in the future?
6770static thread_local char scratch[SCRATCH_SIZE] = {0 };
6871
69- static std::unordered_map<int64_t , brgemm_cache_info_t > &get_tl_cache () {
70- thread_local std::unordered_map<int64_t , brgemm_cache_info_t > tl_cache;
71- return tl_cache;
72- }
73-
7472extern " C" {
7573
7674int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
7775 int64_t LDB, int64_t LDC, int64_t stride_a,
7876 int64_t stride_b, float beta, int64_t dtypeA,
7977 int64_t dtypeB) {
80- brgemm_desc_t desc;
81- brgemm_kernel_t *kernel;
82-
8378 auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
8479 auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
8580 int64_t dtypeA_size = dnnl::impl::types::data_type_size (dnnl_dtypeA);
8681 int64_t dtypeB_size = dnnl::impl::types::data_type_size (dnnl_dtypeB);
8782 brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
8883
84+ write_lock_guard_t g (g_brgemm_lock);
85+ g_kernel_id++;
86+ assert (g_kernel_id < MAX_KERNEL_SIZE &&
87+ " Too many brgemm kernels are created" );
88+ if (g_kernel_id >= DEFAULT_KERNEL_SIZE) {
89+ if (g_kernel_id >= (int64_t )g_cache.size ()) {
90+ g_cache.resize (g_kernel_id + 1 );
91+ }
92+ }
93+
8994 dnnl::impl::status_t status = brgemm_desc_init (
90- &desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd ,
91- dnnl_dtypeA, dnnl_dtypeB, /* transA= */ false , /* transB= */ false ,
92- brgemm_layout_t ::brgemm_row_major, 1 . 0f , beta, LDA, LDB, LDC, M, N, K ,
93- &stride_info);
95+ &g_cache[g_kernel_id]. desc , cpu_isa_t ::isa_undef,
96+ brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB ,
97+ /* transA= */ false , /* transB= */ false , brgemm_layout_t ::brgemm_row_major ,
98+ 1 . 0f , beta, LDA, LDB, LDC, M, N, K, &stride_info);
9499 assert (status == dnnl::impl::status::success &&
95100 " Failed to initialize BRGEMM descriptor" );
96101
97- status = brgemm_kernel_create (&kernel, desc);
102+ status = brgemm_kernel_create (&g_cache[g_kernel_id].kernel ,
103+ g_cache[g_kernel_id].desc );
98104 assert (status == dnnl::impl::status::success &&
99105 " Failed to JIT BRGEMM kernel" );
100106
101107 brgemm_attr_t dnnl_attrs;
102- brgemm_desc_set_attr (&desc, dnnl_attrs);
108+ brgemm_desc_set_attr (&g_cache[g_kernel_id]. desc , dnnl_attrs);
103109
104- // TODO(haixin): Reuse identical palettes across kernels
105- std::shared_ptr<char []> palette_buffer;
106- if (desc.is_tmm ) {
107- palette_buffer.reset (new char [PALETTE_SIZE]);
108- dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer.get ());
110+ if (g_cache[g_kernel_id].desc .is_tmm ) {
111+ g_cache[g_kernel_id].palette .reset (new char [PALETTE_SIZE]);
112+ status = brgemm_init_tiles (g_cache[g_kernel_id].desc ,
113+ g_cache[g_kernel_id].palette .get ());
109114 assert (status == dnnl::impl::status::success &&
110115 " Failed to initialize palette for BRGEMM" );
111116 }
112117
113- write_lock_guard_t g (g_brgemm_lock);
114- g_cache.push_back (brgemm_cache_info_t {desc, kernel, palette_buffer});
115- return g_cache.size () - 1 ;
118+ return g_kernel_id;
116119}
117120
118121void dnnl_brgemm_tileconfig (int64_t kernel_idx) {
119- assert (kernel_idx >= 0 && " Invalid kernel handler" );
120- auto &tl_cache = get_tl_cache ();
121- auto it = tl_cache.find (kernel_idx);
122- if (it == tl_cache.end ()) {
123- read_lock_guard_t g (g_brgemm_lock);
124- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
125- it = tl_cache.insert ({kernel_idx, g_cache[kernel_idx]}).first ;
122+ std::unique_ptr<read_lock_guard_t > lock_guard;
123+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
124+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
126125 }
127- brgemm_desc_t &desc = it-> second . desc ;
128- char *palette_buffer = it-> second . palette . get ( );
129-
126+ assert (kernel_idx >= 0 && kernel_idx < ( int64_t )g_cache. size () &&
127+ " Invalid kernel handler " );
128+ brgemm_desc_t &desc = g_cache[kernel_idx]. desc ;
130129 if (!desc.is_tmm ) {
131130 return ;
132131 }
133-
132+ char *palette_buffer = g_cache[kernel_idx]. palette . get ();
134133 assert (palette_buffer != nullptr && " Invalid palette for BRGEMM kernel" );
135134 amx_tile_configure (palette_buffer);
136135}
@@ -146,35 +145,26 @@ void dnnl_brgemm_tilerelease() {
146145void dnnl_brgemm_execute (int64_t kernel_idx, void *A, uint64_t A_offset,
147146 void *B, uint64_t B_offset, void *C, uint64_t C_offset,
148147 int num) {
149- auto &tl_cache = get_tl_cache ();
150- if (tl_cache.find (kernel_idx) == tl_cache.end ()) {
151- read_lock_guard_t g (g_brgemm_lock);
152- assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_cache.size () &&
153- " Invalid kernel handler" );
154- auto updated_cache =
155- tl_cache.insert (std::make_pair (kernel_idx, g_cache[kernel_idx]));
156- assert (updated_cache.second && " insert into thread local cache" );
148+ std::unique_ptr<read_lock_guard_t > lock_guard;
149+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
150+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
157151 }
158- auto it = tl_cache. find (kernel_idx);
159- brgemm_kernel_t *kernel = it-> second . kernel ;
160- brgemm_desc_t *desc_ptr = &it-> second .desc ;
161-
152+ assert (kernel_idx >= 0 && kernel_idx < ( int64_t )g_cache. size () &&
153+ " Invalid kernel handler " ) ;
154+ brgemm_desc_t &desc = g_cache[kernel_idx] .desc ;
155+ brgemm_kernel_t *kernel = g_cache[kernel_idx]. kernel ;
162156 assert (kernel && " Invalid brgemm kernel pointer" );
163- assert (desc_ptr && " Invalid brgemm descriptor pointer" );
164-
165157 size_t A_offset_in_bytes =
166- dnnl::impl::types::data_type_size (desc_ptr-> dt_a ) * A_offset;
158+ dnnl::impl::types::data_type_size (desc. dt_a ) * A_offset;
167159 size_t B_offset_in_bytes =
168- dnnl::impl::types::data_type_size (desc_ptr-> dt_b ) * B_offset;
160+ dnnl::impl::types::data_type_size (desc. dt_b ) * B_offset;
169161 size_t C_offset_in_bytes =
170- dnnl::impl::types::data_type_size (desc_ptr->dt_c ) * C_offset;
171-
172- char *A_arith = (char *)A;
173- char *B_arith = (char *)B;
174- char *C_arith = (char *)C;
175- brgemm_kernel_execute (kernel, num, (void *)(A_arith + A_offset_in_bytes),
176- (void *)(B_arith + B_offset_in_bytes), nullptr ,
177- (void *)(C_arith + C_offset_in_bytes), (void *)scratch);
162+ dnnl::impl::types::data_type_size (desc.dt_c ) * C_offset;
163+ char *A_arith = static_cast <char *>(A) + A_offset_in_bytes;
164+ char *B_arith = static_cast <char *>(B) + B_offset_in_bytes;
165+ char *C_arith = static_cast <char *>(C) + C_offset_in_bytes;
166+ brgemm_kernel_execute (kernel, num, A_arith, B_arith, nullptr , C_arith,
167+ scratch);
178168}
179169}
180170
0 commit comments