@@ -48,93 +48,90 @@ __attribute__((weak)) void print_verbose_header() {}
4848} // namespace dnnl
4949
5050static constexpr int PALETTE_SIZE = 64 ;
51+ static constexpr int DEFAULT_KERNEL_SIZE = 1024 ;
5152
5253using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
5354using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
5455static std::shared_mutex g_brgemm_lock;
5556
5657struct brgemm_cache_info_t {
57- std::shared_ptr< brgemm_desc_t > desc;
58+ brgemm_desc_t desc;
5859 brgemm_kernel_t *kernel;
59- std::shared_ptr <char []> palette;
60+ std::unique_ptr <char []> palette;
6061};
6162
62- static std::vector<brgemm_cache_info_t > g_cache;
63+ static std::vector<brgemm_cache_info_t > g_cache (DEFAULT_KERNEL_SIZE);
64+ static int64_t kernel_id = -1 ;
6365
6466// TODO(haixin): use syscall to determine page size?
6567static constexpr size_t SCRATCH_SIZE = 2 * 4096 ;
6668// TODO(haixin): need to use custom thread management for scratch in the future?
6769static thread_local char scratch[SCRATCH_SIZE] = {0 };
6870
69- static std::vector<brgemm_cache_info_t > &get_tl_cache () {
70- thread_local std::vector<brgemm_cache_info_t > tl_cache;
71- return tl_cache;
72- }
73-
7471extern " C" {
7572
7673int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
7774 int64_t LDB, int64_t LDC, int64_t stride_a,
7875 int64_t stride_b, float beta, int64_t dtypeA,
7976 int64_t dtypeB) {
80- std::shared_ptr<brgemm_desc_t > desc_ptr = std::make_shared<brgemm_desc_t >();
81- brgemm_desc_t *desc = desc_ptr.get ();
82- brgemm_kernel_t *kernel;
8377 auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
8478 auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
8579 int64_t dtypeA_size = dnnl::impl::types::data_type_size (dnnl_dtypeA);
8680 int64_t dtypeB_size = dnnl::impl::types::data_type_size (dnnl_dtypeB);
8781 brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
8882
83+ write_lock_guard_t g (g_brgemm_lock);
84+ kernel_id++;
85+
86+ if (kernel_id >= DEFAULT_KERNEL_SIZE) {
87+ if (kernel_id >= g_cache.size ()) {
88+ g_cache.resize (kernel_id + 1 );
89+ }
90+ }
91+
8992 dnnl::impl::status_t status = brgemm_desc_init (
90- desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA,
91- dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
92- brgemm_layout_t ::brgemm_row_major, 1 .0f , beta, LDA, LDB, LDC, M, N, K,
93- &stride_info);
94- assert (status == dnnl::impl::status::success &&
95- " Failed to initialize BRGEMM descriptor" );
93+ &g_cache[kernel_id].desc , cpu_isa_t ::isa_undef,
94+ brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB,
95+ /* transA=*/ false , /* transB=*/ false , brgemm_layout_t ::brgemm_row_major,
96+ 1 .0f , beta, LDA, LDB, LDC, M, N, K, &stride_info);
97+ if (status != dnnl::impl::status::success) {
98+ return -1 ;
99+ }
96100
97- status = brgemm_kernel_create (&kernel, *desc);
98- assert (status == dnnl::impl::status::success &&
99- " Failed to JIT BRGEMM kernel" );
101+ status =
102+ brgemm_kernel_create (&g_cache[kernel_id].kernel , g_cache[kernel_id].desc );
103+ if (status != dnnl::impl::status::success) {
104+ return -1 ;
105+ }
100106
101107 brgemm_attr_t dnnl_attrs;
102- brgemm_desc_set_attr (desc, dnnl_attrs);
103-
104- // TODO(haixin): Reuse identical palettes across kernels
105- std::shared_ptr<char []> palette_buffer;
106- if (desc->is_tmm ) {
107- palette_buffer.reset (new char [PALETTE_SIZE]);
108- dnnl::impl::status_t status =
109- brgemm_init_tiles (*desc, palette_buffer.get ());
110- assert (status == dnnl::impl::status::success &&
111- " Failed to initialize palette for BRGEMM" );
108+ brgemm_desc_set_attr (&g_cache[kernel_id].desc , dnnl_attrs);
109+
110+ if (g_cache[kernel_id].desc .is_tmm ) {
111+ g_cache[kernel_id].palette .reset (new char [PALETTE_SIZE]);
112+ status = brgemm_init_tiles (g_cache[kernel_id].desc ,
113+ g_cache[kernel_id].palette .get ());
114+ if (status != dnnl::impl::status::success) {
115+ return -1 ;
116+ }
112117 }
113118
114- write_lock_guard_t g (g_brgemm_lock);
115- g_cache.push_back (brgemm_cache_info_t {desc_ptr, kernel, palette_buffer});
116- return g_cache.size () - 1 ;
119+ return kernel_id;
117120}
118121
119122void dnnl_brgemm_tileconfig (int64_t kernel_idx) {
120- assert (kernel_idx >= 0 && " Invalid kernel handler" );
121- auto &tl_cache = get_tl_cache ();
122- if (kernel_idx >= (int64_t )tl_cache.size () ||
123- tl_cache[kernel_idx].kernel == nullptr ) {
124- read_lock_guard_t g (g_brgemm_lock);
125- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
126- if (kernel_idx >= (int64_t )tl_cache.size ()) {
127- tl_cache.resize (kernel_idx + 1 );
128- }
129- tl_cache[kernel_idx] = g_cache[kernel_idx];
123+ // Declare the lock guard outside the if block to extend its lifetime
124+ std::unique_ptr<read_lock_guard_t > lock_guard;
125+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
126+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
130127 }
131- brgemm_desc_t *desc = tl_cache[ kernel_idx]. desc . get ();
132- char *palette_buffer = tl_cache[kernel_idx]. palette . get ( );
133-
134- if (!desc-> is_tmm ) {
128+ assert (kernel_idx >= 0 && kernel_idx < ( int64_t )g_cache. size () &&
129+ " Invalid kernel handler " );
130+ brgemm_desc_t &desc = g_cache[kernel_idx]. desc ;
131+ if (!desc. is_tmm ) {
135132 return ;
136133 }
137-
134+ char *palette_buffer = g_cache[kernel_idx]. palette . get ();
138135 assert (palette_buffer != nullptr && " Invalid palette for BRGEMM kernel" );
139136 amx_tile_configure (palette_buffer);
140137}
@@ -150,35 +147,27 @@ void dnnl_brgemm_tilerelease() {
150147void dnnl_brgemm_execute (int64_t kernel_idx, void *A, uint64_t A_offset,
151148 void *B, uint64_t B_offset, void *C, uint64_t C_offset,
152149 int num) {
153- auto &tl_cache = get_tl_cache ();
154- if (kernel_idx >= (int64_t )tl_cache.size () ||
155- tl_cache[kernel_idx].kernel == nullptr ) {
156- read_lock_guard_t g (g_brgemm_lock);
157- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
158- if (kernel_idx >= (int64_t )tl_cache.size ()) {
159- tl_cache.resize (kernel_idx + 1 );
160- }
161- tl_cache[kernel_idx] = g_cache[kernel_idx];
150+ // Acquire the lock only if needed
151+ std::unique_ptr<read_lock_guard_t > lock_guard;
152+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
153+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
162154 }
163- brgemm_kernel_t *kernel = tl_cache[kernel_idx].kernel ;
164- brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc .get ();
165-
155+ assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_cache.size () &&
156+ " Invalid kernel handler" );
157+ brgemm_desc_t &desc = g_cache[kernel_idx].desc ;
158+ brgemm_kernel_t *kernel = g_cache[kernel_idx].kernel ;
166159 assert (kernel && " Invalid brgemm kernel pointer" );
167- assert (desc_ptr && " Invalid brgemm descriptor pointer" );
168-
169160 size_t A_offset_in_bytes =
170- dnnl::impl::types::data_type_size (desc_ptr-> dt_a ) * A_offset;
161+ dnnl::impl::types::data_type_size (desc. dt_a ) * A_offset;
171162 size_t B_offset_in_bytes =
172- dnnl::impl::types::data_type_size (desc_ptr-> dt_b ) * B_offset;
163+ dnnl::impl::types::data_type_size (desc. dt_b ) * B_offset;
173164 size_t C_offset_in_bytes =
174- dnnl::impl::types::data_type_size (desc_ptr->dt_c ) * C_offset;
175-
176- char *A_arith = (char *)A;
177- char *B_arith = (char *)B;
178- char *C_arith = (char *)C;
179- brgemm_kernel_execute (kernel, num, (void *)(A_arith + A_offset_in_bytes),
180- (void *)(B_arith + B_offset_in_bytes), nullptr ,
181- (void *)(C_arith + C_offset_in_bytes), (void *)scratch);
165+ dnnl::impl::types::data_type_size (desc.dt_c ) * C_offset;
166+ char *A_arith = static_cast <char *>(A) + A_offset_in_bytes;
167+ char *B_arith = static_cast <char *>(B) + B_offset_in_bytes;
168+ char *C_arith = static_cast <char *>(C) + C_offset_in_bytes;
169+ brgemm_kernel_execute (kernel, num, A_arith, B_arith, nullptr , C_arith,
170+ scratch);
182171}
183172}
184173
0 commit comments