@@ -181,6 +181,62 @@ __global__ void spaz(int n, const float* input, float* output) {
181181 CHECK (stream.span (spaz.template_params [0 ].name ) == " tiling_factor" );
182182}
183183
184+ /*
185+ ,
186+
187+
188+
189+
190+
191+ */
192+
193+ TEST_CASE (" parser difficult parameters" ) {
194+ std::string input = R"(
195+ #pragma kernel_tuner tune(block_size=32, 64, 128, 256) default(128)
196+ #pragma kernel_tuner problem_size(n)
197+ __global__ void foo(
198+ int n,
199+ float __restrict__* a,
200+ my_type<float> b,
201+ bar::my_type<const float>* c,
202+ bar::my_type<const float, baz::tuple<int, float>> d,
203+ tuple<tuple<tuple<int>, float>, double, tuple<short, tuple<>>> e
204+ ) {
205+ if (threadIdx.x < 10) {
206+ return a[threadIdx.x];
207+ }
208+ }
209+ )" ;
210+
211+ auto stream = internal::TokenStream (" <stdin>" , input);
212+ auto result = extract_annotated_kernels (stream);
213+
214+ const auto & kernels = result.kernels ;
215+ REQUIRE (kernels.size () == 1 );
216+
217+ const auto & foo = kernels[0 ];
218+ CHECK (foo.qualified_name == " foo" );
219+ REQUIRE (foo.fun_params .size () == 6 );
220+
221+ CHECK (stream.span (foo.fun_params [0 ].name ) == " n" );
222+ CHECK (stream.span (foo.fun_params [1 ].name ) == " a" );
223+ CHECK (stream.span (foo.fun_params [2 ].name ) == " b" );
224+ CHECK (stream.span (foo.fun_params [3 ].name ) == " c" );
225+ CHECK (stream.span (foo.fun_params [4 ].name ) == " d" );
226+ CHECK (stream.span (foo.fun_params [5 ].name ) == " e" );
227+
228+ CHECK (foo.fun_params [0 ].type == " int" );
229+ CHECK (foo.fun_params [1 ].type == " float __restrict__*" );
230+ CHECK (foo.fun_params [2 ].type == " my_type<float>" );
231+ CHECK (foo.fun_params [3 ].type == " bar::my_type<const float>*" );
232+ CHECK (
233+ foo.fun_params [4 ].type
234+ == " bar::my_type<const float, baz::tuple<int, float>>" );
235+ CHECK (
236+ foo.fun_params [5 ].type
237+ == " tuple<tuple<tuple<int>, float>, double, tuple<short, tuple<>>>" );
238+ }
239+
184240TEST_CASE (" directives" ) {
185241 std::string input = R"(
186242 namespace bar {
@@ -210,4 +266,4 @@ TEST_CASE("directives") {
210266 KernelSource source (" <stdin>" , result.processed_source );
211267 KernelBuilder builder =
212268 builder_from_annotated_kernel (stream, source, kernels[0 ], {" float" });
213- }
269+ }
0 commit comments