Skip to content

Commit e25a11f

Browse files
committed
Add support restrictions containing kernel arguments
1 parent adb0e61 commit e25a11f

File tree

4 files changed

+74
-7
lines changed

4 files changed

+74
-7
lines changed

include/kernel_launcher/builder.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define KERNEL_LAUNCHER_BUILDER_H
33

44
#include <unordered_map>
5+
#include <utility>
56
#include <vector>
67

78
#include "kernel_launcher/arg.h"
@@ -24,11 +25,13 @@ struct KernelInstance {
2425
CudaModule module,
2526
std::array<TypedExpr<uint32_t>, 3> block_size,
2627
std::array<TypedExpr<uint32_t>, 3> grid_size,
27-
TypedExpr<uint32_t> shared_mem) :
28+
TypedExpr<uint32_t> shared_mem = 0,
29+
std::vector<TypedExpr<bool>> assertions = {}) :
2830
module_(std::move(module)),
2931
block_size_(std::move(block_size)),
3032
grid_size_(std::move(grid_size)),
31-
shared_mem_(std::move(shared_mem)) {}
33+
shared_mem_(std::move(shared_mem)),
34+
assertions_(std::move(assertions)) {}
3235

3336
void launch(
3437
cudaStream_t stream,
@@ -52,6 +55,7 @@ struct KernelInstance {
5255
std::array<TypedExpr<uint32_t>, 3> block_size_ = {1, 1, 1};
5356
std::array<TypedExpr<uint32_t>, 3> grid_size_ = {0, 0, 0};
5457
TypedExpr<uint32_t> shared_mem_ = 0;
58+
std::vector<TypedExpr<bool>> assertions_;
5559
};
5660

5761
struct KernelBuilderSerializerHack;

include/kernel_launcher/config.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,18 +203,33 @@ struct ConfigSpace {
203203
*/
204204
void restriction(TypedExpr<bool> e);
205205

206+
/**
207+
* Returns the restrictions added by `restriction()`
208+
*/
209+
const std::vector<TypedExpr<bool>>& restrictions() const {
210+
return restrictions_;
211+
}
212+
206213
/**
207214
* Returns the default configuration for this configuration space.
208215
*/
209216
Config default_config() const;
210217

211218
/**
212-
* Check if the given configuration is a valid member of this configuration
213-
* space. This method essentially checks three things:
219+
* Check if the given configuration is a member of this configuration
220+
* space. This method essentially checks two things:
214221
*
215222
* * Does the configuration contain the correct parameters.
216223
* * Do these parameter contain valid values.
217-
* * Does the configuration meet the restrictions.
224+
*
225+
* However, it does _not_ check if the configuration satisfies the
226+
* restrictions.
227+
*/
228+
bool contains(const Eval& config) const;
229+
230+
/**
231+
* Check if the given configuration is a valid member of this configuration
232+
* space and also meets the restrictions.
218233
*/
219234
bool is_valid(const Eval& config) const;
220235

src/builder.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,18 @@ void KernelInstance::launch(
131131
}
132132
}
133133

134+
// We check for assertions now after printing the debug information. This
135+
// allows one to check the debugging output to see what where the arguments
136+
// provided to kernel that caused the assertion to fail.
137+
for (const auto& assertion : assertions_) {
138+
if (!eval(assertion)) {
139+
std::stringstream ss;
140+
ss << "failed to launch kernel `" << module_.function_name()
141+
<< "`, assertion failed: `" << assertion.to_string() << "`";
142+
throw std::runtime_error(ss.str());
143+
}
144+
}
145+
134146
module_.launch(stream, grid_size, block_size, smem, ptrs.data());
135147
}
136148

@@ -444,7 +456,34 @@ KernelInstance KernelBuilder::compile(
444456
const std::vector<TypeInfo>& param_types,
445457
const ICompiler& compiler,
446458
CudaContextHandle ctx) const {
459+
if (!contains(config)) {
460+
std::stringstream ss;
461+
ss << "invalid configuration: `" << config << "`";
462+
throw std::runtime_error(ss.str());
463+
}
464+
447465
DeviceAttrEval eval = {ctx.device(), config};
466+
std::vector<TypedExpr<bool>> assertions;
467+
468+
for (const auto& restriction : restrictions()) {
469+
auto r = restriction.resolve(eval);
470+
471+
if (!r.is_constant()) {
472+
// Any restriction that contain kernel arguments cannot be resolved
473+
// now at this moment. We add these to the list of assertions
474+
// that will be checked each time the kernel gets launched.
475+
assertions.emplace_back(r);
476+
continue;
477+
}
478+
479+
if (!r.eval(eval)) {
480+
std::stringstream ss;
481+
ss << "configuration `" << config
482+
<< "` does not meet the following restriction: `"
483+
<< restriction.to_string() << "`";
484+
throw std::runtime_error(ss.str());
485+
}
486+
}
448487

449488
if (!is_valid(eval)) {
450489
std::stringstream ss;
@@ -469,7 +508,8 @@ KernelInstance KernelBuilder::compile(
469508
std::move(module),
470509
std::move(block_size),
471510
std::move(grid_size),
472-
shared_mem};
511+
std::move(shared_mem),
512+
std::move(assertions)};
473513
}
474514

475515
} // namespace kernel_launcher

src/config.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ Config ConfigSpace::default_config() const {
111111
return config;
112112
}
113113

114-
bool ConfigSpace::is_valid(const Eval& config) const {
114+
bool ConfigSpace::contains(const Eval& config) const {
115115
for (const auto& p : params_) {
116116
Value v;
117117

@@ -124,6 +124,14 @@ bool ConfigSpace::is_valid(const Eval& config) const {
124124
}
125125
}
126126

127+
return true;
128+
}
129+
130+
bool ConfigSpace::is_valid(const Eval& config) const {
131+
if (!contains(config)) {
132+
return false;
133+
}
134+
127135
for (const auto& r : restrictions_) {
128136
if (!config(r)) {
129137
return false;

0 commit comments

Comments
 (0)