@@ -48,93 +48,86 @@ __attribute__((weak)) void print_verbose_header() {}
4848} // namespace dnnl
4949
5050static constexpr int PALETTE_SIZE = 64 ;
51+ static constexpr int DEFAULT_KERNEL_SIZE = 1024 ;
5152
5253using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
5354using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
5455static std::shared_mutex g_brgemm_lock;
5556
5657struct brgemm_cache_info_t {
57- std::shared_ptr< brgemm_desc_t > desc;
58+ brgemm_desc_t desc;
5859 brgemm_kernel_t *kernel;
59- std::shared_ptr <char []> palette;
60+ std::unique_ptr <char []> palette;
6061};
6162
62- static std::vector<brgemm_cache_info_t > g_cache;
63+ static std::vector<brgemm_cache_info_t > g_cache (DEFAULT_KERNEL_SIZE);
64+ static int64_t kernel_id = -1 ;
6365
6466// TODO(haixin): use syscall to determine page size?
6567static constexpr size_t SCRATCH_SIZE = 2 * 4096 ;
6668// TODO(haixin): need to use custom thread management for scratch in the future?
6769static thread_local char scratch[SCRATCH_SIZE] = {0 };
6870
69- static std::vector<brgemm_cache_info_t > &get_tl_cache () {
70- thread_local std::vector<brgemm_cache_info_t > tl_cache;
71- return tl_cache;
72- }
73-
7471extern " C" {
7572
7673int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
7774 int64_t LDB, int64_t LDC, int64_t stride_a,
7875 int64_t stride_b, float beta, int64_t dtypeA,
7976 int64_t dtypeB) {
80- std::shared_ptr<brgemm_desc_t > desc_ptr = std::make_shared<brgemm_desc_t >();
81- brgemm_desc_t *desc = desc_ptr.get ();
82- brgemm_kernel_t *kernel;
8377 auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
8478 auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
8579 int64_t dtypeA_size = dnnl::impl::types::data_type_size (dnnl_dtypeA);
8680 int64_t dtypeB_size = dnnl::impl::types::data_type_size (dnnl_dtypeB);
8781 brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
8882
83+ write_lock_guard_t g (g_brgemm_lock);
84+ kernel_id++;
85+
86+ if (kernel_id >= DEFAULT_KERNEL_SIZE) {
87+ if (kernel_id >= (int64_t )g_cache.size ()) {
88+ g_cache.resize (kernel_id + 1 );
89+ }
90+ }
91+
8992 dnnl::impl::status_t status = brgemm_desc_init (
90- desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA ,
91- dnnl_dtypeB, /* transA= */ false , /* transB= */ false ,
92- brgemm_layout_t ::brgemm_row_major, 1 . 0f , beta, LDA, LDB, LDC, M, N, K ,
93- &stride_info);
93+ &g_cache[kernel_id]. desc , cpu_isa_t ::isa_undef,
94+ brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB ,
95+ /* transA= */ false , /* transB= */ false , brgemm_layout_t ::brgemm_row_major ,
96+ 1 . 0f , beta, LDA, LDB, LDC, M, N, K, &stride_info);
9497 assert (status == dnnl::impl::status::success &&
9598 " Failed to initialize BRGEMM descriptor" );
9699
97- status = brgemm_kernel_create (&kernel, *desc);
100+ status =
101+ brgemm_kernel_create (&g_cache[kernel_id].kernel , g_cache[kernel_id].desc );
98102 assert (status == dnnl::impl::status::success &&
99103 " Failed to JIT BRGEMM kernel" );
100104
101105 brgemm_attr_t dnnl_attrs;
102- brgemm_desc_set_attr (desc, dnnl_attrs);
103-
104- // TODO(haixin): Reuse identical palettes across kernels
105- std::shared_ptr<char []> palette_buffer;
106- if (desc->is_tmm ) {
107- palette_buffer.reset (new char [PALETTE_SIZE]);
108- dnnl::impl::status_t status =
109- brgemm_init_tiles (*desc, palette_buffer.get ());
106+ brgemm_desc_set_attr (&g_cache[kernel_id].desc , dnnl_attrs);
107+
108+ if (g_cache[kernel_id].desc .is_tmm ) {
109+ g_cache[kernel_id].palette .reset (new char [PALETTE_SIZE]);
110+ status = brgemm_init_tiles (g_cache[kernel_id].desc ,
111+ g_cache[kernel_id].palette .get ());
110112 assert (status == dnnl::impl::status::success &&
111113 " Failed to initialize palette for BRGEMM" );
112114 }
113115
114- write_lock_guard_t g (g_brgemm_lock);
115- g_cache.push_back (brgemm_cache_info_t {desc_ptr, kernel, palette_buffer});
116- return g_cache.size () - 1 ;
116+ return kernel_id;
117117}
118118
119119void dnnl_brgemm_tileconfig (int64_t kernel_idx) {
120- assert (kernel_idx >= 0 && " Invalid kernel handler" );
121- auto &tl_cache = get_tl_cache ();
122- if (kernel_idx >= (int64_t )tl_cache.size () ||
123- tl_cache[kernel_idx].kernel == nullptr ) {
124- read_lock_guard_t g (g_brgemm_lock);
125- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
126- if (kernel_idx >= (int64_t )tl_cache.size ()) {
127- tl_cache.resize (kernel_idx + 1 );
128- }
129- tl_cache[kernel_idx] = g_cache[kernel_idx];
120+ std::unique_ptr<read_lock_guard_t > lock_guard;
121+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
122+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
130123 }
131- brgemm_desc_t *desc = tl_cache[ kernel_idx]. desc . get ();
132- char *palette_buffer = tl_cache[kernel_idx]. palette . get ( );
133-
134- if (!desc-> is_tmm ) {
124+ assert (kernel_idx >= 0 && kernel_idx < ( int64_t )g_cache. size () &&
125+ " Invalid kernel handler " );
126+ brgemm_desc_t &desc = g_cache[kernel_idx]. desc ;
127+ if (!desc. is_tmm ) {
135128 return ;
136129 }
137-
130+ char *palette_buffer = g_cache[kernel_idx]. palette . get ();
138131 assert (palette_buffer != nullptr && " Invalid palette for BRGEMM kernel" );
139132 amx_tile_configure (palette_buffer);
140133}
@@ -150,35 +143,26 @@ void dnnl_brgemm_tilerelease() {
150143void dnnl_brgemm_execute (int64_t kernel_idx, void *A, uint64_t A_offset,
151144 void *B, uint64_t B_offset, void *C, uint64_t C_offset,
152145 int num) {
153- auto &tl_cache = get_tl_cache ();
154- if (kernel_idx >= (int64_t )tl_cache.size () ||
155- tl_cache[kernel_idx].kernel == nullptr ) {
156- read_lock_guard_t g (g_brgemm_lock);
157- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
158- if (kernel_idx >= (int64_t )tl_cache.size ()) {
159- tl_cache.resize (kernel_idx + 1 );
160- }
161- tl_cache[kernel_idx] = g_cache[kernel_idx];
146+ std::unique_ptr<read_lock_guard_t > lock_guard;
147+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
148+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
162149 }
163- brgemm_kernel_t *kernel = tl_cache[kernel_idx].kernel ;
164- brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc .get ();
165-
150+ assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_cache.size () &&
151+ " Invalid kernel handler" );
152+ brgemm_desc_t &desc = g_cache[kernel_idx].desc ;
153+ brgemm_kernel_t *kernel = g_cache[kernel_idx].kernel ;
166154 assert (kernel && " Invalid brgemm kernel pointer" );
167- assert (desc_ptr && " Invalid brgemm descriptor pointer" );
168-
169155 size_t A_offset_in_bytes =
170- dnnl::impl::types::data_type_size (desc_ptr-> dt_a ) * A_offset;
156+ dnnl::impl::types::data_type_size (desc. dt_a ) * A_offset;
171157 size_t B_offset_in_bytes =
172- dnnl::impl::types::data_type_size (desc_ptr-> dt_b ) * B_offset;
158+ dnnl::impl::types::data_type_size (desc. dt_b ) * B_offset;
173159 size_t C_offset_in_bytes =
174- dnnl::impl::types::data_type_size (desc_ptr->dt_c ) * C_offset;
175-
176- char *A_arith = (char *)A;
177- char *B_arith = (char *)B;
178- char *C_arith = (char *)C;
179- brgemm_kernel_execute (kernel, num, (void *)(A_arith + A_offset_in_bytes),
180- (void *)(B_arith + B_offset_in_bytes), nullptr ,
181- (void *)(C_arith + C_offset_in_bytes), (void *)scratch);
160+ dnnl::impl::types::data_type_size (desc.dt_c ) * C_offset;
161+ char *A_arith = static_cast <char *>(A) + A_offset_in_bytes;
162+ char *B_arith = static_cast <char *>(B) + B_offset_in_bytes;
163+ char *C_arith = static_cast <char *>(C) + C_offset_in_bytes;
164+ brgemm_kernel_execute (kernel, num, A_arith, B_arith, nullptr , C_arith,
165+ scratch);
182166}
183167}
184168
0 commit comments