@@ -19,82 +19,91 @@ struct KernelDescriptor {
1919 }
2020};
2121
22- struct KernelRegistry {
22+ struct AnyKernelDescriptor {
23+ AnyKernelDescriptor (AnyKernelDescriptor&&) noexcept = default ;
24+ AnyKernelDescriptor (const AnyKernelDescriptor&) = default ;
25+
26+ template <typename D>
27+ AnyKernelDescriptor (D&& descriptor) {
28+ using T = typename std::decay<D>::type;
29+ descriptor_type_ = type_of<T>();
30+ descriptor_ = std::make_unique<T>(std::forward<D>(descriptor));
31+ hash_ = hash_fields (descriptor_type_, descriptor_->hash ());
32+ }
33+
34+ const KernelDescriptor& descriptor () const {
35+ return *descriptor_;
36+ }
37+
38+ hash_t hash () const {
39+ return hash_;
40+ }
41+
42+ bool operator ==(const AnyKernelDescriptor& that) const {
43+ return that.hash_ == hash_ && that.descriptor_type_ == descriptor_type_
44+ && that.descriptor_ ->equals (*descriptor_);
45+ }
46+
47+ bool operator !=(const AnyKernelDescriptor& that) const {
48+ return !(*this == that);
49+ }
50+
2351 private:
24- struct CacheKey {
25- struct hasher {
26- size_t operator ()(const CacheKey& key) const ;
27- };
28-
29- struct equals {
30- bool operator ()(const CacheKey& lhs, const CacheKey& rhs) const ;
31- };
32-
33- CacheKey (CacheKey&&) noexcept = default ;
34-
35- template <typename D>
36- explicit CacheKey (D&& descriptor) {
37- using T = typename std::decay<D>::type;
38- descriptor_type_ = type_of<T>();
39- descriptor_ = std::make_unique<T>(std::forward<D>(descriptor));
40- hash_ = hash_fields (descriptor_type_, descriptor_->hash ());
41- }
42-
43- const KernelDescriptor& descriptor () const {
44- return *descriptor_;
45- }
46-
47- private:
48- hash_t hash_;
49- TypeInfo descriptor_type_;
50- std::unique_ptr<KernelDescriptor> descriptor_;
51- };
52-
53- public:
52+ hash_t hash_;
53+ TypeInfo descriptor_type_;
54+ std::shared_ptr<KernelDescriptor> descriptor_;
55+ };
56+ }
57+
58+ namespace std {
59+ template <>
60+ struct hash <kernel_launcher::AnyKernelDescriptor> {
61+ size_t operator ()(const kernel_launcher::AnyKernelDescriptor& d) const {
62+ return d.hash ();
63+ }
64+ };
65+ }
66+
67+ namespace kernel_launcher {
68+ struct KernelRegistry {
5469 explicit KernelRegistry (
5570 Compiler compiler = default_compiler(),
5671 WisdomSettings settings = default_wisdom_settings()) :
5772 compiler_(std::move(compiler)),
5873 settings_(std::move(settings)) {}
5974
60- template <typename D>
61- WisdomKernel& lookup (D&& descriptor) const {
62- return lookup_internal (CacheKey (std::forward<D>(descriptor)));
75+ WisdomKernel& lookup (AnyKernelDescriptor descriptor) const {
76+ return lookup_internal (std::move (descriptor));
6377 }
6478
65- template <typename D>
66- WisdomKernelLaunch
67- instantiate (D&& descriptor, cudaStream_t stream, ProblemSize problem_size)
68- const {
69- return lookup (std::forward<D>(descriptor))
70- .instantiate (stream, problem_size);
79+ WisdomKernelLaunch instantiate (
80+ AnyKernelDescriptor descriptor,
81+ cudaStream_t stream,
82+ ProblemSize problem_size) const {
83+ return lookup (std::move (descriptor)).instantiate (stream, problem_size);
7184 }
7285
73- template < typename D>
74- WisdomKernelLaunch
75- operator ()(D&& descriptor, cudaStream_t stream, ProblemSize problem_size)
76- const {
77- return instantiate (std::forward<D> (descriptor), stream, problem_size);
86+ WisdomKernelLaunch operator ()(
87+ AnyKernelDescriptor descriptor,
88+ cudaStream_t stream,
89+ ProblemSize problem_size) const {
90+ return instantiate (std::move (descriptor), stream, problem_size);
7891 }
7992
80- template <typename D>
8193 WisdomKernelLaunch
82- operator ()(D&& descriptor, ProblemSize problem_size) const {
83- return instantiate (std::forward<D> (descriptor), nullptr , problem_size);
94+ operator ()(AnyKernelDescriptor descriptor, ProblemSize problem_size) const {
95+ return instantiate (std::move (descriptor), nullptr , problem_size);
8496 }
8597
8698 private:
87- WisdomKernel& lookup_internal (CacheKey key) const ;
99+ WisdomKernel& lookup_internal (AnyKernelDescriptor key) const ;
88100
89101 Compiler compiler_;
90102 WisdomSettings settings_;
91103 mutable std::mutex mutex_;
92- mutable std::unordered_map<
93- CacheKey,
94- std::unique_ptr<WisdomKernel>,
95- CacheKey::hasher,
96- CacheKey::equals>
97- cache_ = {};
104+ mutable std::
105+ unordered_map<AnyKernelDescriptor, std::unique_ptr<WisdomKernel>>
106+ cache_ = {};
98107};
99108
100109const KernelRegistry& default_registry ();
@@ -113,4 +122,5 @@ WisdomKernelLaunch launch(D&& descriptor, ProblemSize size) {
113122
114123} // namespace kernel_launcher
115124
125+
116126#endif // KERNEL_LAUNCHER_CACHE_H
0 commit comments