Skip to content

Commit 440992e

Browse files
committed
Add support for preheaders in KernelBuilder
1 parent 7a3eab8 commit 440992e

File tree

4 files changed

+37
-14
lines changed

4 files changed

+37
-14
lines changed

include/kernel_launcher/kernel.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ struct KernelBuilder: ConfigSpace {
9191
KernelBuilder& assertion(TypedExpr<bool> e);
9292
KernelBuilder& define(std::string name, TypedExpr<std::string> value);
9393
KernelBuilder& compiler_flag(TypedExpr<std::string> opt);
94+
KernelBuilder& include_header(KernelSource source);
9495

9596
template<typename T, typename... Ts>
9697
KernelBuilder& template_args(T&& first, Ts&&... rest) {
@@ -152,6 +153,7 @@ struct KernelBuilder: ConfigSpace {
152153
private:
153154
std::string kernel_name_;
154155
KernelSource kernel_source_;
156+
std::vector<KernelSource> preheaders_;
155157
std::array<TypedExpr<uint32_t>, 3> block_size_ = {1u, 1u, 1u};
156158
std::array<TypedExpr<uint32_t>, 3> grid_size_ = {1u, 1u, 1u};
157159
TypedExpr<uint32_t> shared_mem_ = {0u};

include/kernel_launcher/registry.h

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ struct KernelDescriptor {
2121

2222
struct AnyKernelDescriptor {
2323
AnyKernelDescriptor(AnyKernelDescriptor&&) noexcept = default;
24+
AnyKernelDescriptor(AnyKernelDescriptor&) noexcept = default;
2425
AnyKernelDescriptor(const AnyKernelDescriptor&) = default;
2526

2627
template<typename D>
2728
AnyKernelDescriptor(D&& descriptor) {
2829
using T = typename std::decay<D>::type;
2930
descriptor_type_ = type_of<T>();
30-
descriptor_ = std::make_unique<T>(std::forward<D>(descriptor));
31+
descriptor_ = std::make_shared<T>(std::forward<D>(descriptor));
3132
hash_ = hash_fields(descriptor_type_, descriptor_->hash());
3233
}
3334

@@ -53,16 +54,16 @@ struct AnyKernelDescriptor {
5354
TypeInfo descriptor_type_;
5455
std::shared_ptr<KernelDescriptor> descriptor_;
5556
};
56-
}
57+
} // namespace kernel_launcher
5758

5859
namespace std {
59-
template <>
60+
template<>
6061
struct hash<kernel_launcher::AnyKernelDescriptor> {
6162
size_t operator()(const kernel_launcher::AnyKernelDescriptor& d) const {
6263
return d.hash();
6364
}
6465
};
65-
}
66+
} // namespace std
6667

6768
namespace kernel_launcher {
6869
struct KernelRegistry {
@@ -108,19 +109,16 @@ struct KernelRegistry {
108109

109110
const KernelRegistry& default_registry();
110111

111-
template<typename D>
112-
WisdomKernelLaunch
113-
launch(D&& descriptor, cudaStream_t stream, ProblemSize size) {
114-
return default_registry().instantiate(
115-
std::forward<D>(descriptor, stream, size));
112+
inline WisdomKernelLaunch
113+
launch(AnyKernelDescriptor descriptor, cudaStream_t stream, ProblemSize size) {
114+
return default_registry().instantiate(descriptor, stream, size);
116115
}
117116

118-
template<typename D>
119-
WisdomKernelLaunch launch(D&& descriptor, ProblemSize size) {
120-
return launch(std::forward<D>(descriptor), (cudaStream_t) nullptr, size);
117+
inline WisdomKernelLaunch
118+
launch(AnyKernelDescriptor descriptor, ProblemSize size) {
119+
return launch(descriptor, (cudaStream_t) nullptr, size);
121120
}
122121

123122
} // namespace kernel_launcher
124123

125-
126124
#endif //KERNEL_LAUNCHER_CACHE_H

src/export.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,20 @@ struct KernelBuilderSerializerHack {
175175
}
176176

177177
static json builder_to_json(const KernelBuilder& builder) {
178-
std::unordered_map<std::string, json> defines;
178+
std::vector<json> headers;
179+
for (const auto& source : builder.preheaders_) {
180+
json content = nullptr;
181+
if (source.content() != nullptr) {
182+
content = *source.content();
183+
}
184+
185+
headers.push_back({
186+
{"file", source.file_name()},
187+
{"content", std::move(content)},
188+
});
189+
}
190+
191+
json defines;
179192
for (const auto& p : builder.defines_) {
180193
defines[p.first] = expr_to_json(p.second);
181194
}
@@ -195,6 +208,7 @@ struct KernelBuilderSerializerHack {
195208
result["shared_memory"] = expr_to_json(builder.shared_mem_);
196209
result["template_args"] = expr_list_to_json(builder.template_args_);
197210
result["defines"] = std::move(defines);
211+
result["headers"] = std::move(headers);
198212

199213
return result;
200214
}

src/kernel.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ KernelBuilder& KernelBuilder::compiler_flag(TypedExpr<std::string> opt) {
5858
return *this;
5959
}
6060

61+
KernelBuilder& KernelBuilder::include_header(KernelSource source) {
62+
preheaders_.push_back(std::move(source));
63+
return *this;
64+
}
65+
6166
KernelDef KernelBuilder::build(
6267
const Config& config,
6368
const std::vector<TypeInfo>& param_types) const {
@@ -86,6 +91,10 @@ KernelDef KernelBuilder::build(
8691
}
8792
}
8893

94+
for (const auto& source : preheaders_) {
95+
def.add_preincluded_header(source);
96+
}
97+
8998
for (const auto& p : defines_) {
9099
def.add_compiler_option("--define-macro");
91100
def.add_compiler_option(p.first + "=" + eval(p.second));

0 commit comments

Comments
 (0)