@@ -53,15 +53,24 @@ using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
5353using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
5454static std::shared_mutex g_brgemm_lock;
5555
56- static std::vector<brgemm_desc_t > g_brgemm_desc_list;
57- static std::vector<brgemm_kernel_t *> g_brgemm_kernel_list;
58- static std::vector<std::unique_ptr<char []>> g_brgemm_palette;
56+ struct brgemm_cache_info_t {
57+ brgemm_desc_t desc;
58+ brgemm_kernel_t *kernel;
59+ std::shared_ptr<char []> palette;
60+ };
61+
62+ static std::vector<brgemm_cache_info_t > g_cache;
5963
6064// TODO(haixin): use syscall to determine page size?
6165static constexpr size_t SCRATCH_SIZE = 2 * 4096 ;
6266// TODO(haixin): need to use custom thread management for scratch in the future?
6367static thread_local char scratch[SCRATCH_SIZE] = {0 };
6468
69+ static std::unordered_map<int64_t , brgemm_cache_info_t > &get_tl_cache () {
70+ thread_local std::unordered_map<int64_t , brgemm_cache_info_t > tl_cache;
71+ return tl_cache;
72+ }
73+
6574extern " C" {
6675
6776int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
@@ -93,33 +102,33 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
93102 brgemm_desc_set_attr (&desc, dnnl_attrs);
94103
95104 // TODO(haixin): Reuse identical palettes across kernels
96- char * palette_buffer = nullptr ;
105+ std::shared_ptr< char []> palette_buffer;
97106 if (desc.is_tmm ) {
98- palette_buffer = new char [PALETTE_SIZE];
99- dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer);
107+ palette_buffer. reset ( new char [PALETTE_SIZE]) ;
108+ dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer. get () );
100109 assert (status == dnnl::impl::status::success &&
101110 " Failed to initialize palette for BRGEMM" );
102111 }
103112
104113 write_lock_guard_t g (g_brgemm_lock);
105- g_brgemm_desc_list.push_back (desc);
106- g_brgemm_kernel_list.push_back (kernel);
107- g_brgemm_palette.emplace_back (palette_buffer);
108-
109- return g_brgemm_desc_list.size () - 1 ;
114+ g_cache.push_back (brgemm_cache_info_t {desc, kernel, palette_buffer});
115+ return g_cache.size () - 1 ;
110116}
111117
112118void dnnl_brgemm_tileconfig (int64_t kernel_idx) {
113- char *palette_buffer = nullptr ;
114- {
119+ assert (kernel_idx >= 0 && " Invalid kernel handler" );
120+ auto &tl_cache = get_tl_cache ();
121+ auto it = tl_cache.find (kernel_idx);
122+ if (it == tl_cache.end ()) {
115123 read_lock_guard_t g (g_brgemm_lock);
116- assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_brgemm_desc_list.size () &&
117- " Invalid kernel handler" );
118- brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
119- if (!desc.is_tmm ) {
120- return ;
121- }
122- palette_buffer = g_brgemm_palette[kernel_idx].get ();
124+ assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
125+ it = tl_cache.insert ({kernel_idx, g_cache[kernel_idx]}).first ;
126+ }
127+ brgemm_desc_t &desc = it->second .desc ;
128+ char *palette_buffer = it->second .palette .get ();
129+
130+ if (!desc.is_tmm ) {
131+ return ;
123132 }
124133
125134 assert (palette_buffer != nullptr && " Invalid palette for BRGEMM kernel" );
@@ -137,24 +146,29 @@ void dnnl_brgemm_tilerelease() {
137146void dnnl_brgemm_execute (int64_t kernel_idx, void *A, uint64_t A_offset,
138147 void *B, uint64_t B_offset, void *C, uint64_t C_offset,
139148 int num) {
140- brgemm_kernel_t *kernel = nullptr ;
141- size_t A_offset_in_bytes;
142- size_t B_offset_in_bytes;
143- size_t C_offset_in_bytes;
144- {
149+ auto &tl_cache = get_tl_cache ();
150+ if (tl_cache.find (kernel_idx) == tl_cache.end ()) {
145151 read_lock_guard_t g (g_brgemm_lock);
146- assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_brgemm_desc_list .size () &&
152+ assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_cache .size () &&
147153 " Invalid kernel handler" );
148-
149- brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
150- kernel = g_brgemm_kernel_list[kernel_idx];
151-
152- A_offset_in_bytes = dnnl::impl::types::data_type_size (desc.dt_a ) * A_offset;
153- B_offset_in_bytes = dnnl::impl::types::data_type_size (desc.dt_b ) * B_offset;
154- C_offset_in_bytes = dnnl::impl::types::data_type_size (desc.dt_c ) * C_offset;
154+ auto updated_cache =
155+ tl_cache.insert (std::make_pair (kernel_idx, g_cache[kernel_idx]));
156+ assert (updated_cache.second && " insert into thread local cache" );
155157 }
158+ auto it = tl_cache.find (kernel_idx);
159+ brgemm_kernel_t *kernel = it->second .kernel ;
160+ brgemm_desc_t *desc_ptr = &it->second .desc ;
156161
157162 assert (kernel && " Invalid brgemm kernel pointer" );
163+ assert (desc_ptr && " Invalid brgemm descriptor pointer" );
164+
165+ size_t A_offset_in_bytes =
166+ dnnl::impl::types::data_type_size (desc_ptr->dt_a ) * A_offset;
167+ size_t B_offset_in_bytes =
168+ dnnl::impl::types::data_type_size (desc_ptr->dt_b ) * B_offset;
169+ size_t C_offset_in_bytes =
170+ dnnl::impl::types::data_type_size (desc_ptr->dt_c ) * C_offset;
171+
158172 char *A_arith = (char *)A;
159173 char *B_arith = (char *)B;
160174 char *C_arith = (char *)C;
0 commit comments