Skip to content

Commit e6c53cf

Browse files
committed
Add support for templated types as parameters in pragma kernels
1 parent ca5e477 commit e6c53cf

File tree

3 files changed

+71
-4
lines changed

3 files changed

+71
-4
lines changed

src/internal/parser.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,18 @@ static std::vector<FunctionParam> parse_kernel_params(TokenStream& stream) {
4646
Token before_name = begin;
4747
Token name = stream.next();
4848
Token end = stream.peek();
49+
int template_depth = 0;
50+
51+
while (template_depth > 0
52+
|| !(
53+
end.kind == TokenKind::Comma
54+
|| end.kind == TokenKind::ParenR)) {
55+
if (name.kind == TokenKind::AngleL) {
56+
template_depth++;
57+
} else if (name.kind == TokenKind::AngleR && template_depth > 0) {
58+
template_depth--;
59+
}
4960

50-
while (end.kind != TokenKind::Comma && end.kind != TokenKind::ParenR) {
5161
before_name = name;
5262
name = stream.next();
5363
end = stream.peek();

src/internal/tokens.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ static index_t advance_string(index_t i, const std::string& input) {
8787
TokenKind char2_to_kind(char a, char b) {
8888
if ((a == '=' && b == '=') || (a == '!' && b == '=')
8989
|| (a == '<' && b == '=') || (a == '>' && b == '=')
90-
|| (a == '&' && b == '&') || (a == '|' && b == '|')
91-
|| (a == '<' && b == '<') || (a == '>' && b == '>')
90+
|| (a == '&' && b == '&')
91+
|| (a == '|' && b == '|')
92+
//|| (a == '<' && b == '<') || (a == '>' && b == '>')
9293
|| (a == ':' && b == ':')) {
9394
return TokenKind::Punct;
9495
}

tests/internal.cpp

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,62 @@ __global__ void spaz(int n, const float* input, float* output) {
181181
CHECK(stream.span(spaz.template_params[0].name) == "tiling_factor");
182182
}
183183

184+
/*
185+
,
186+
187+
188+
189+
190+
191+
*/
192+
193+
TEST_CASE("parser difficult parameters") {
194+
std::string input = R"(
195+
#pragma kernel_tuner tune(block_size=32, 64, 128, 256) default(128)
196+
#pragma kernel_tuner problem_size(n)
197+
__global__ void foo(
198+
int n,
199+
float __restrict__* a,
200+
my_type<float> b,
201+
bar::my_type<const float>* c,
202+
bar::my_type<const float, baz::tuple<int, float>> d,
203+
tuple<tuple<tuple<int>, float>, double, tuple<short, tuple<>>> e
204+
) {
205+
if (threadIdx.x < 10) {
206+
return a[threadIdx.x];
207+
}
208+
}
209+
)";
210+
211+
auto stream = internal::TokenStream("<stdin>", input);
212+
auto result = extract_annotated_kernels(stream);
213+
214+
const auto& kernels = result.kernels;
215+
REQUIRE(kernels.size() == 1);
216+
217+
const auto& foo = kernels[0];
218+
CHECK(foo.qualified_name == "foo");
219+
REQUIRE(foo.fun_params.size() == 6);
220+
221+
CHECK(stream.span(foo.fun_params[0].name) == "n");
222+
CHECK(stream.span(foo.fun_params[1].name) == "a");
223+
CHECK(stream.span(foo.fun_params[2].name) == "b");
224+
CHECK(stream.span(foo.fun_params[3].name) == "c");
225+
CHECK(stream.span(foo.fun_params[4].name) == "d");
226+
CHECK(stream.span(foo.fun_params[5].name) == "e");
227+
228+
CHECK(foo.fun_params[0].type == "int");
229+
CHECK(foo.fun_params[1].type == "float __restrict__*");
230+
CHECK(foo.fun_params[2].type == "my_type<float>");
231+
CHECK(foo.fun_params[3].type == "bar::my_type<const float>*");
232+
CHECK(
233+
foo.fun_params[4].type
234+
== "bar::my_type<const float, baz::tuple<int, float>>");
235+
CHECK(
236+
foo.fun_params[5].type
237+
== "tuple<tuple<tuple<int>, float>, double, tuple<short, tuple<>>>");
238+
}
239+
184240
TEST_CASE("directives") {
185241
std::string input = R"(
186242
namespace bar {
@@ -210,4 +266,4 @@ TEST_CASE("directives") {
210266
KernelSource source("<stdin>", result.processed_source);
211267
KernelBuilder builder =
212268
builder_from_annotated_kernel(stream, source, kernels[0], {"float"});
213-
}
269+
}

0 commit comments

Comments
 (0)