|
| 1 | +// SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation |
| 2 | +// |
| 3 | +// SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +#include "kernel_caching.h" |
| 6 | +#include <unordered_map> |
| 7 | + |
| 8 | +extern "C" |
| 9 | +{ |
| 10 | +#include "dpctl_capi.h" |
| 11 | +#include "dpctl_sycl_interface.h" |
| 12 | + |
| 13 | +#include "_dbg_printer.h" |
| 14 | + |
| 15 | +#include "numba/core/runtime/nrt_external.h" |
| 16 | +} |
| 17 | + |
| 18 | +#include "syclinterface/dpctl_sycl_type_casters.hpp" |
| 19 | +#include "tools/boost_hash.hpp" |
| 20 | +#include "tools/dpctl.hpp" |
| 21 | + |
| 22 | +using CacheKey = std::tuple<DPCTLSyclContextRef, DPCTLSyclDeviceRef, size_t>; |
| 23 | + |
| 24 | +namespace std |
| 25 | +{ |
| 26 | +template <> struct hash<CacheKey> |
| 27 | +{ |
| 28 | + size_t operator()(const CacheKey &ck) const |
| 29 | + { |
| 30 | + std::size_t seed = 0; |
| 31 | + boost::hash_combine(seed, std::get<DPCTLSyclDeviceRef>(ck)); |
| 32 | + boost::hash_combine(seed, std::get<DPCTLSyclContextRef>(ck)); |
| 33 | + boost::hash_detail::hash_combine_impl(seed, std::get<size_t>(ck)); |
| 34 | + return seed; |
| 35 | + } |
| 36 | +}; |
| 37 | +template <> struct equal_to<CacheKey> |
| 38 | +{ |
| 39 | + constexpr bool operator()(const CacheKey &lhs, const CacheKey &rhs) const |
| 40 | + { |
| 41 | + return DPCTLDevice_AreEq(std::get<DPCTLSyclDeviceRef>(lhs), |
| 42 | + std::get<DPCTLSyclDeviceRef>(rhs)) && |
| 43 | + DPCTLContext_AreEq(std::get<DPCTLSyclContextRef>(lhs), |
| 44 | + std::get<DPCTLSyclContextRef>(rhs)) && |
| 45 | + std::get<size_t>(lhs) == std::get<size_t>(rhs); |
| 46 | + } |
| 47 | +}; |
| 48 | +} // namespace std |
| 49 | + |
| 50 | +// TODO: add cache cleaning |
| 51 | +// https://github.com/IntelPython/numba-dpex/issues/1240 |
| 52 | +std::unordered_map<CacheKey, DPCTLSyclKernelRef> sycl_kernel_cache = |
| 53 | + std::unordered_map<CacheKey, DPCTLSyclKernelRef>(); |
| 54 | + |
| 55 | +template <class M, class Key, class F> |
| 56 | +typename M::mapped_type &get_else_compute(M &m, Key const &k, F f) |
| 57 | +{ |
| 58 | + typedef typename M::mapped_type V; |
| 59 | + std::pair<typename M::iterator, bool> r = |
| 60 | + m.insert(typename M::value_type(k, V())); |
| 61 | + V &v = r.first->second; |
| 62 | + if (r.second) { |
| 63 | + DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: building kernel.\n");); |
| 64 | + f(v); |
| 65 | + } |
| 66 | + else { |
| 67 | + DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: using cached kernel.\n");); |
| 68 | + DPCTLDevice_Delete(std::get<DPCTLSyclDeviceRef>(k)); |
| 69 | + DPCTLContext_Delete(std::get<DPCTLSyclContextRef>(k)); |
| 70 | + } |
| 71 | + return v; |
| 72 | +} |
| 73 | + |
| 74 | +extern "C" |
| 75 | +{ |
| 76 | + DPCTLSyclKernelRef DPEXRT_build_or_get_kernel(const DPCTLSyclContextRef ctx, |
| 77 | + const DPCTLSyclDeviceRef dev, |
| 78 | + size_t il_hash, |
| 79 | + const char *il, |
| 80 | + size_t il_length, |
| 81 | + const char *compile_opts, |
| 82 | + const char *kernel_name) |
| 83 | + { |
| 84 | + DPEXRT_DEBUG( |
| 85 | + drt_debug_print("DPEXRT-DEBUG: in build or get kernel.\n");); |
| 86 | + |
| 87 | + CacheKey key = std::make_tuple(ctx, dev, il_hash); |
| 88 | + |
| 89 | + DPEXRT_DEBUG(auto ctx_hash = std::hash<DPCTLSyclContextRef>{}(ctx); |
| 90 | + auto dev_hash = std::hash<DPCTLSyclDeviceRef>{}(dev); |
| 91 | + drt_debug_print("DPEXRT-DEBUG: key hashes: %d %d %d.\n", |
| 92 | + ctx_hash, dev_hash, il_hash);); |
| 93 | + |
| 94 | + auto k_ref = get_else_compute( |
| 95 | + sycl_kernel_cache, key, |
| 96 | + [ctx, dev, il, il_length, compile_opts, |
| 97 | + kernel_name](DPCTLSyclKernelRef &k_ref) { |
| 98 | + auto kb_ref = DPCTLKernelBundle_CreateFromSpirv( |
| 99 | + ctx, dev, il, il_length, compile_opts); |
| 100 | + k_ref = DPCTLKernelBundle_GetKernel(kb_ref, kernel_name); |
| 101 | + DPCTLKernelBundle_Delete(kb_ref); |
| 102 | + }); |
| 103 | + |
| 104 | + DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: kernel hash size: %d.\n", |
| 105 | + sycl_kernel_cache.size());); |
| 106 | + |
| 107 | + return DPCTLKernel_Copy(k_ref); |
| 108 | + } |
| 109 | + |
| 110 | + size_t DPEXRT_kernel_cache_size() { return sycl_kernel_cache.size(); } |
| 111 | +} |
0 commit comments