@@ -54,7 +54,7 @@ using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
5454static std::shared_mutex g_brgemm_lock;
5555
5656struct brgemm_cache_info_t {
57- brgemm_desc_t * desc;
57+ std::shared_ptr< brgemm_desc_t > desc;
5858 brgemm_kernel_t *kernel;
5959 std::shared_ptr<char []> palette;
6060};
@@ -70,14 +70,15 @@ static std::vector<brgemm_cache_info_t> &get_tl_cache() {
7070 thread_local std::vector<brgemm_cache_info_t > tl_cache;
7171 return tl_cache;
7272}
73- brgemm_desc_t desc;
7473
7574extern " C" {
7675
7776int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
7877 int64_t LDB, int64_t LDC, int64_t stride_a,
7978 int64_t stride_b, float beta, int64_t dtypeA,
8079 int64_t dtypeB) {
80+ std::shared_ptr<brgemm_desc_t > desc_ptr = std::make_shared<brgemm_desc_t >();
81+ brgemm_desc_t *desc = desc_ptr.get ();
8182 brgemm_kernel_t *kernel;
8283 auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
8384 auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
@@ -86,31 +87,32 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
8687 brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
8788
8889 dnnl::impl::status_t status = brgemm_desc_init (
89- & desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd,
90- dnnl_dtypeA, dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
90+ desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA ,
91+ dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
9192 brgemm_layout_t ::brgemm_row_major, 1 .0f , beta, LDA, LDB, LDC, M, N, K,
9293 &stride_info);
9394 assert (status == dnnl::impl::status::success &&
9495 " Failed to initialize BRGEMM descriptor" );
9596
96- status = brgemm_kernel_create (&kernel, desc);
97+ status = brgemm_kernel_create (&kernel, * desc);
9798 assert (status == dnnl::impl::status::success &&
9899 " Failed to JIT BRGEMM kernel" );
99100
100101 brgemm_attr_t dnnl_attrs;
101- brgemm_desc_set_attr (& desc, dnnl_attrs);
102+ brgemm_desc_set_attr (desc, dnnl_attrs);
102103
103104 // TODO(haixin): Reuse identical palettes across kernels
104105 std::shared_ptr<char []> palette_buffer;
105- if (desc. is_tmm ) {
106+ if (desc-> is_tmm ) {
106107 palette_buffer.reset (new char [PALETTE_SIZE]);
107- dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer.get ());
108+ dnnl::impl::status_t status =
109+ brgemm_init_tiles (*desc, palette_buffer.get ());
108110 assert (status == dnnl::impl::status::success &&
109111 " Failed to initialize palette for BRGEMM" );
110112 }
111113
112114 write_lock_guard_t g (g_brgemm_lock);
113- g_cache.push_back (brgemm_cache_info_t {&desc , kernel, palette_buffer});
115+ g_cache.push_back (brgemm_cache_info_t {desc_ptr , kernel, palette_buffer});
114116 return g_cache.size () - 1 ;
115117}
116118
@@ -126,10 +128,10 @@ void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
126128 }
127129 tl_cache[kernel_idx] = g_cache[kernel_idx];
128130 }
129- brgemm_desc_t & desc = * tl_cache[kernel_idx].desc ;
131+ brgemm_desc_t * desc = tl_cache[kernel_idx].desc . get () ;
130132 char *palette_buffer = tl_cache[kernel_idx].palette .get ();
131133
132- if (!desc. is_tmm ) {
134+ if (!desc-> is_tmm ) {
133135 return ;
134136 }
135137
@@ -159,7 +161,7 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
159161 tl_cache[kernel_idx] = g_cache[kernel_idx];
160162 }
161163 brgemm_kernel_t *kernel = tl_cache[kernel_idx].kernel ;
162- brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc ;
164+ brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc . get () ;
163165
164166 assert (kernel && " Invalid brgemm kernel pointer" );
165167 assert (desc_ptr && " Invalid brgemm descriptor pointer" );
0 commit comments