diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index 8e91a0f9..25138dbd 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -48,6 +48,8 @@ __attribute__((weak)) void print_verbose_header() {} } // namespace dnnl static constexpr int PALETTE_SIZE = 64; +static constexpr int DEFAULT_KERNEL_SIZE = 1024; +static constexpr int MAX_KERNEL_SIZE = 2048; using read_lock_guard_t = std::shared_lock; using write_lock_guard_t = std::unique_lock; @@ -56,81 +58,78 @@ static std::shared_mutex g_brgemm_lock; struct brgemm_cache_info_t { brgemm_desc_t desc; brgemm_kernel_t *kernel; - std::shared_ptr palette; + std::unique_ptr palette; }; -static std::vector g_cache; +static std::vector g_cache(DEFAULT_KERNEL_SIZE); +static int64_t g_kernel_id = -1; // TODO(haixin): use syscall to determine page size? static constexpr size_t SCRATCH_SIZE = 2 * 4096; // TODO(haixin): need to use custom thread management for scratch in the future? static thread_local char scratch[SCRATCH_SIZE] = {0}; -static std::unordered_map &get_tl_cache() { - thread_local std::unordered_map tl_cache; - return tl_cache; -} - extern "C" { int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, int64_t dtypeA, int64_t dtypeB) { - brgemm_desc_t desc; - brgemm_kernel_t *kernel; - auto dnnl_dtypeA = static_cast(dtypeA); auto dnnl_dtypeB = static_cast(dtypeB); int64_t dtypeA_size = dnnl::impl::types::data_type_size(dnnl_dtypeA); int64_t dtypeB_size = dnnl::impl::types::data_type_size(dnnl_dtypeB); brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size}; + write_lock_guard_t g(g_brgemm_lock); + g_kernel_id++; + assert(g_kernel_id < MAX_KERNEL_SIZE && + "Too many brgemm kernels are created"); + if (g_kernel_id >= DEFAULT_KERNEL_SIZE) { + if (g_kernel_id >= (int64_t)g_cache.size()) { + g_cache.resize(g_kernel_id + 1); + } + } + dnnl::impl::status_t status = brgemm_desc_init( - &desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd, - dnnl_dtypeA, dnnl_dtypeB, /*transA=*/false, /*transB=*/false, - brgemm_layout_t::brgemm_row_major, 1.0f, beta, LDA, LDB, LDC, M, N, K, - &stride_info); + &g_cache[g_kernel_id].desc, cpu_isa_t::isa_undef, + brgemm_batch_kind_t::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB, + /*transA=*/false, /*transB=*/false, brgemm_layout_t::brgemm_row_major, + 1.0f, beta, LDA, LDB, LDC, M, N, K, &stride_info); assert(status == dnnl::impl::status::success && "Failed to initialize BRGEMM descriptor"); - status = brgemm_kernel_create(&kernel, desc); + status = brgemm_kernel_create(&g_cache[g_kernel_id].kernel, + g_cache[g_kernel_id].desc); assert(status == dnnl::impl::status::success && "Failed to JIT BRGEMM kernel"); brgemm_attr_t dnnl_attrs; - brgemm_desc_set_attr(&desc, dnnl_attrs); + brgemm_desc_set_attr(&g_cache[g_kernel_id].desc, dnnl_attrs); - // TODO(haixin): Reuse identical palettes across kernels - std::shared_ptr palette_buffer; - if (desc.is_tmm) { - palette_buffer.reset(new char[PALETTE_SIZE]); - dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer.get()); + if (g_cache[g_kernel_id].desc.is_tmm) { + g_cache[g_kernel_id].palette.reset(new char[PALETTE_SIZE]); + status = brgemm_init_tiles(g_cache[g_kernel_id].desc, + g_cache[g_kernel_id].palette.get()); assert(status == dnnl::impl::status::success && "Failed to initialize palette for BRGEMM"); } - write_lock_guard_t g(g_brgemm_lock); - g_cache.push_back(brgemm_cache_info_t{desc, kernel, palette_buffer}); - return g_cache.size() - 1; + return g_kernel_id; } void dnnl_brgemm_tileconfig(int64_t kernel_idx) { - assert(kernel_idx >= 0 && "Invalid kernel handler"); - auto &tl_cache = get_tl_cache(); - auto it = tl_cache.find(kernel_idx); - if (it == tl_cache.end()) { - read_lock_guard_t g(g_brgemm_lock); - assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler"); - it = tl_cache.insert({kernel_idx, g_cache[kernel_idx]}).first; + std::unique_ptr lock_guard; + if (kernel_idx >= DEFAULT_KERNEL_SIZE) { + lock_guard = std::make_unique(g_brgemm_lock); } - brgemm_desc_t &desc = it->second.desc; - char *palette_buffer = it->second.palette.get(); - + assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() && + "Invalid kernel handler"); + brgemm_desc_t &desc = g_cache[kernel_idx].desc; if (!desc.is_tmm) { return; } - + char *palette_buffer = g_cache[kernel_idx].palette.get(); assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel"); amx_tile_configure(palette_buffer); } @@ -146,35 +145,26 @@ void dnnl_brgemm_tilerelease() { void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, void *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) { - auto &tl_cache = get_tl_cache(); - if (tl_cache.find(kernel_idx) == tl_cache.end()) { - read_lock_guard_t g(g_brgemm_lock); - assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() && - "Invalid kernel handler"); - auto updated_cache = - tl_cache.insert(std::make_pair(kernel_idx, g_cache[kernel_idx])); - assert(updated_cache.second && "insert into thread local cache"); + std::unique_ptr lock_guard; + if (kernel_idx >= DEFAULT_KERNEL_SIZE) { + lock_guard = std::make_unique(g_brgemm_lock); } - auto it = tl_cache.find(kernel_idx); - brgemm_kernel_t *kernel = it->second.kernel; - brgemm_desc_t *desc_ptr = &it->second.desc; - + assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() && + "Invalid kernel handler"); + brgemm_desc_t &desc = g_cache[kernel_idx].desc; + brgemm_kernel_t *kernel = g_cache[kernel_idx].kernel; assert(kernel && "Invalid brgemm kernel pointer"); - assert(desc_ptr && "Invalid brgemm descriptor pointer"); - size_t A_offset_in_bytes = - dnnl::impl::types::data_type_size(desc_ptr->dt_a) * A_offset; + dnnl::impl::types::data_type_size(desc.dt_a) * A_offset; size_t B_offset_in_bytes = - dnnl::impl::types::data_type_size(desc_ptr->dt_b) * B_offset; + dnnl::impl::types::data_type_size(desc.dt_b) * B_offset; size_t C_offset_in_bytes = - dnnl::impl::types::data_type_size(desc_ptr->dt_c) * C_offset; - - char *A_arith = (char *)A; - char *B_arith = (char *)B; - char *C_arith = (char *)C; - brgemm_kernel_execute(kernel, num, (void *)(A_arith + A_offset_in_bytes), - (void *)(B_arith + B_offset_in_bytes), nullptr, - (void *)(C_arith + C_offset_in_bytes), (void *)scratch); + dnnl::impl::types::data_type_size(desc.dt_c) * C_offset; + char *A_arith = static_cast(A) + A_offset_in_bytes; + char *B_arith = static_cast(B) + B_offset_in_bytes; + char *C_arith = static_cast(C) + C_offset_in_bytes; + brgemm_kernel_execute(kernel, num, A_arith, B_arith, nullptr, C_arith, + scratch); } }