@@ -70,14 +70,15 @@ static std::vector<brgemm_cache_info_t> &get_tl_cache() {
7070 thread_local std::vector<brgemm_cache_info_t > tl_cache;
7171 return tl_cache;
7272}
73- brgemm_desc_t desc;
7473
7574extern " C" {
7675
7776int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
7877 int64_t LDB, int64_t LDC, int64_t stride_a,
7978 int64_t stride_b, float beta, int64_t dtypeA,
8079 int64_t dtypeB) {
80+ std::shared_ptr<brgemm_desc_t > desc_ptr = std::make_shared<brgemm_desc_t >();
81+ brgemm_desc_t *desc = desc_ptr.get ();
8182 brgemm_kernel_t *kernel;
8283 auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
8384 auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
@@ -86,31 +87,32 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
8687 brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
8788
8889 dnnl::impl::status_t status = brgemm_desc_init (
89- & desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd,
90- dnnl_dtypeA, dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
90+ desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA ,
91+ dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
9192 brgemm_layout_t ::brgemm_row_major, 1 .0f , beta, LDA, LDB, LDC, M, N, K,
9293 &stride_info);
9394 assert (status == dnnl::impl::status::success &&
9495 " Failed to initialize BRGEMM descriptor" );
9596
96- status = brgemm_kernel_create (&kernel, desc);
97+ status = brgemm_kernel_create (&kernel, * desc);
9798 assert (status == dnnl::impl::status::success &&
9899 " Failed to JIT BRGEMM kernel" );
99100
100101 brgemm_attr_t dnnl_attrs;
101- brgemm_desc_set_attr (& desc, dnnl_attrs);
102+ brgemm_desc_set_attr (desc, dnnl_attrs);
102103
103104 // TODO(haixin): Reuse identical palettes across kernels
104105 std::shared_ptr<char []> palette_buffer;
105- if (desc. is_tmm ) {
106+ if (desc-> is_tmm ) {
106107 palette_buffer.reset (new char [PALETTE_SIZE]);
107- dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer.get ());
108+ dnnl::impl::status_t status =
109+ brgemm_init_tiles (*desc, palette_buffer.get ());
108110 assert (status == dnnl::impl::status::success &&
109111 " Failed to initialize palette for BRGEMM" );
110112 }
111113
112114 write_lock_guard_t g (g_brgemm_lock);
113- g_cache.push_back (brgemm_cache_info_t {& desc, kernel, palette_buffer});
115+ g_cache.push_back (brgemm_cache_info_t {desc, kernel, palette_buffer});
114116 return g_cache.size () - 1 ;
115117}
116118
0 commit comments