Skip to content

Commit 7a3eab8

Browse files
committed
Add AnyKernelDescriptor class to replace KernelRegistry::CacheKey
1 parent 47c320d commit 7a3eab8

File tree

2 files changed

+66
-68
lines changed

2 files changed

+66
-68
lines changed

include/kernel_launcher/registry.h

Lines changed: 65 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -19,82 +19,91 @@ struct KernelDescriptor {
1919
}
2020
};
2121

22-
struct KernelRegistry {
22+
struct AnyKernelDescriptor {
23+
AnyKernelDescriptor(AnyKernelDescriptor&&) noexcept = default;
24+
AnyKernelDescriptor(const AnyKernelDescriptor&) = default;
25+
26+
template<typename D>
27+
AnyKernelDescriptor(D&& descriptor) {
28+
using T = typename std::decay<D>::type;
29+
descriptor_type_ = type_of<T>();
30+
descriptor_ = std::make_unique<T>(std::forward<D>(descriptor));
31+
hash_ = hash_fields(descriptor_type_, descriptor_->hash());
32+
}
33+
34+
const KernelDescriptor& descriptor() const {
35+
return *descriptor_;
36+
}
37+
38+
hash_t hash() const {
39+
return hash_;
40+
}
41+
42+
bool operator==(const AnyKernelDescriptor& that) const {
43+
return that.hash_ == hash_ && that.descriptor_type_ == descriptor_type_
44+
&& that.descriptor_->equals(*descriptor_);
45+
}
46+
47+
bool operator!=(const AnyKernelDescriptor& that) const {
48+
return !(*this == that);
49+
}
50+
2351
private:
24-
struct CacheKey {
25-
struct hasher {
26-
size_t operator()(const CacheKey& key) const;
27-
};
28-
29-
struct equals {
30-
bool operator()(const CacheKey& lhs, const CacheKey& rhs) const;
31-
};
32-
33-
CacheKey(CacheKey&&) noexcept = default;
34-
35-
template<typename D>
36-
explicit CacheKey(D&& descriptor) {
37-
using T = typename std::decay<D>::type;
38-
descriptor_type_ = type_of<T>();
39-
descriptor_ = std::make_unique<T>(std::forward<D>(descriptor));
40-
hash_ = hash_fields(descriptor_type_, descriptor_->hash());
41-
}
42-
43-
const KernelDescriptor& descriptor() const {
44-
return *descriptor_;
45-
}
46-
47-
private:
48-
hash_t hash_;
49-
TypeInfo descriptor_type_;
50-
std::unique_ptr<KernelDescriptor> descriptor_;
51-
};
52-
53-
public:
52+
hash_t hash_;
53+
TypeInfo descriptor_type_;
54+
std::shared_ptr<KernelDescriptor> descriptor_;
55+
};
56+
}
57+
58+
namespace std {
59+
template <>
60+
struct hash<kernel_launcher::AnyKernelDescriptor> {
61+
size_t operator()(const kernel_launcher::AnyKernelDescriptor& d) const {
62+
return d.hash();
63+
}
64+
};
65+
}
66+
67+
namespace kernel_launcher {
68+
struct KernelRegistry {
5469
explicit KernelRegistry(
5570
Compiler compiler = default_compiler(),
5671
WisdomSettings settings = default_wisdom_settings()) :
5772
compiler_(std::move(compiler)),
5873
settings_(std::move(settings)) {}
5974

60-
template<typename D>
61-
WisdomKernel& lookup(D&& descriptor) const {
62-
return lookup_internal(CacheKey(std::forward<D>(descriptor)));
75+
WisdomKernel& lookup(AnyKernelDescriptor descriptor) const {
76+
return lookup_internal(std::move(descriptor));
6377
}
6478

65-
template<typename D>
66-
WisdomKernelLaunch
67-
instantiate(D&& descriptor, cudaStream_t stream, ProblemSize problem_size)
68-
const {
69-
return lookup(std::forward<D>(descriptor))
70-
.instantiate(stream, problem_size);
79+
WisdomKernelLaunch instantiate(
80+
AnyKernelDescriptor descriptor,
81+
cudaStream_t stream,
82+
ProblemSize problem_size) const {
83+
return lookup(std::move(descriptor)).instantiate(stream, problem_size);
7184
}
7285

73-
template<typename D>
74-
WisdomKernelLaunch
75-
operator()(D&& descriptor, cudaStream_t stream, ProblemSize problem_size)
76-
const {
77-
return instantiate(std::forward<D>(descriptor), stream, problem_size);
86+
WisdomKernelLaunch operator()(
87+
AnyKernelDescriptor descriptor,
88+
cudaStream_t stream,
89+
ProblemSize problem_size) const {
90+
return instantiate(std::move(descriptor), stream, problem_size);
7891
}
7992

80-
template<typename D>
8193
WisdomKernelLaunch
82-
operator()(D&& descriptor, ProblemSize problem_size) const {
83-
return instantiate(std::forward<D>(descriptor), nullptr, problem_size);
94+
operator()(AnyKernelDescriptor descriptor, ProblemSize problem_size) const {
95+
return instantiate(std::move(descriptor), nullptr, problem_size);
8496
}
8597

8698
private:
87-
WisdomKernel& lookup_internal(CacheKey key) const;
99+
WisdomKernel& lookup_internal(AnyKernelDescriptor key) const;
88100

89101
Compiler compiler_;
90102
WisdomSettings settings_;
91103
mutable std::mutex mutex_;
92-
mutable std::unordered_map<
93-
CacheKey,
94-
std::unique_ptr<WisdomKernel>,
95-
CacheKey::hasher,
96-
CacheKey::equals>
97-
cache_ = {};
104+
mutable std::
105+
unordered_map<AnyKernelDescriptor, std::unique_ptr<WisdomKernel>>
106+
cache_ = {};
98107
};
99108

100109
const KernelRegistry& default_registry();
@@ -113,4 +122,5 @@ WisdomKernelLaunch launch(D&& descriptor, ProblemSize size) {
113122

114123
} // namespace kernel_launcher
115124

125+
116126
#endif //KERNEL_LAUNCHER_CACHE_H

src/registry.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,7 @@ const KernelRegistry& default_registry() {
1111
return *global_default_registry;
1212
}
1313

14-
size_t KernelRegistry::CacheKey::hasher::operator()(const CacheKey& key) const {
15-
return key.hash_;
16-
}
17-
18-
bool KernelRegistry::CacheKey::equals::operator()(
19-
const CacheKey& lhs,
20-
const CacheKey& rhs) const {
21-
return lhs.hash_ == rhs.hash_
22-
&& lhs.descriptor_type_ == rhs.descriptor_type_
23-
&& lhs.descriptor_->equals(*rhs.descriptor_);
24-
}
25-
26-
WisdomKernel& KernelRegistry::lookup_internal(CacheKey key) const {
14+
WisdomKernel& KernelRegistry::lookup_internal(AnyKernelDescriptor key) const {
2715
std::lock_guard<std::mutex> guard(mutex_);
2816

2917
auto it = cache_.find(key);

0 commit comments

Comments
 (0)