@@ -242,7 +242,8 @@ void set_argument(sycl::handler &cgh, int index, ordered_json &item) {
242242}
243243
244244static void sycl_kernel_launch (sycl::queue &stream, sycl::kernel &kernel_ptr,
245- KernelArguments triton_args) {
245+ KernelArguments triton_args,
246+ bool get_kernel_time) {
246247 std::string kernel_name =
247248 kernel_ptr.get_info <sycl::info::kernel::function_name>();
248249
@@ -290,7 +291,18 @@ static void sycl_kernel_launch(sycl::queue &stream, sycl::kernel &kernel_ptr,
290291 assert (narg == expected_num_params);
291292 cgh.parallel_for (parallel_work_size, kernel_ptr);
292293 };
293- stream.submit (cgf);
294+ if (get_kernel_time) {
295+ sycl::event event = stream.submit (cgf);
296+ event.wait ();
297+ uint64_t start =
298+ event.get_profiling_info <sycl::info::event_profiling::command_start>();
299+ uint64_t end =
300+ event.get_profiling_info <sycl::info::event_profiling::command_end>();
301+ double duration = static_cast <double >(end - start) / 1000000 ;
302+ std::cout << " Kernel execution time: " << duration << " ms" << std::endl;
303+ } else {
304+ stream.submit (cgf);
305+ }
294306 stream.wait_and_throw ();
295307}
296308
@@ -317,7 +329,7 @@ at::TensorOptions getTensorOptions(const std::string &dtype) {
317329}
318330
319331at::Tensor launchKernel (sycl::queue stream, sycl::kernel kernel,
320- KernelArguments triton_args) {
332+ KernelArguments triton_args, bool get_kernel_time ) {
321333
322334 auto tensor_ptr = [](const torch::Tensor &t) -> void * {
323335 return static_cast <void *>(t.data_ptr ());
@@ -352,8 +364,14 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
352364 }
353365 }
354366
367+ if (!triton_args.host_outbuffer .defined ()) {
368+ std::string message = " Output tensor isn't configured; \
369+ the second positional parameter is " ;
370+ throw std::runtime_error (message + triton_args.out_tensor_name );
371+ }
372+
355373 // Launch SYCL kernel
356- sycl_kernel_launch (stream, kernel, triton_args);
374+ sycl_kernel_launch (stream, kernel, triton_args, get_kernel_time );
357375
358376 // copy back
359377 stream
@@ -372,22 +390,43 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
372390 return triton_args.host_outbuffer ;
373391}
374392
// Returns true iff `option` appears among the optional command-line
// arguments. argv[0] is the executable and argv[1] the mandatory output
// tensor name, so only argv[2..argc-1] is scanned; optional flags may
// appear there in any order.
//
// NOTE(review): the name keeps the existing "amoung" spelling because
// callers (main) use it; consider renaming to check_option_among_argv
// together with its call sites.
bool check_option_amoung_argv(int argc, char **argv, std::string option) {
  if (argc <= 2) {
    return false; // no optional arguments present
  }
  return std::any_of(argv + 2, argv + argc,
                     [&option](const char *arg) { return option == arg; });
}
406+
375407int main (int argc, char **argv) {
376408 try {
377- std::string print_output_kernel_tensor = " --print-output-kernel-tensor " ;
409+ std::string enable_profiling = " --enable-profiling " ;
378410 if (argc < 2 ) {
379411 std::cout << " Help: " << std::endl;
380412 std::cout << " <Executable> <Output Tensor Name>" << std::endl;
381413 std::cout << " ./build/SPIRVRunner tensor_2" << std::endl;
382- std::cout << " To print the output kernel tensor to stdout, use:"
383- << std::endl;
384- std::cout << " ./build/SPIRVRunner tensor_2 " << print_output_kernel_tensor
414+ std::cout << " To get kernel time, use:" << std::endl;
415+ std::cout << " ./build/SPIRVRunner tensor_2 " << enable_profiling
385416 << std::endl;
386417 throw std::runtime_error (" Input arguments are missing \n " );
387418 }
388419
389420 // initialize sycl runtime
390- sycl::queue q = sycl::queue (sycl::gpu_selector_v, exception_handler);
421+ bool get_kernel_time =
422+ check_option_amoung_argv (argc, argv, enable_profiling);
423+ sycl::queue q;
424+ if (get_kernel_time) {
425+ sycl::property_list prop_list{sycl::property::queue::enable_profiling ()};
426+ q = sycl::queue (sycl::gpu_selector_v, exception_handler, prop_list);
427+ } else {
428+ q = sycl::queue (sycl::gpu_selector_v, exception_handler);
429+ }
391430
392431 std::cout << " Running on device: "
393432 << q.get_device ().get_info <sycl::info::device::name>() << " \n " ;
@@ -409,11 +448,7 @@ int main(int argc, char **argv) {
409448 std::cout << " Loaded kernel with " << n_regs << " registers and "
410449 << n_spills << " register spills." << std::endl;
411450
412- auto output = launchKernel (q, kernel, tritonArgDict);
413-
414- if (argc == 3 && argv[2 ] == print_output_kernel_tensor) {
415- std::cout << " Kernel return output: " << output[0 ] << std::endl;
416- }
451+ auto output = launchKernel (q, kernel, tritonArgDict, get_kernel_time);
417452
418453 auto output_tensor = tritonArgDict.spirv_dump_dir + " /cpp_outs.pt" ;
419454 write_tensor (output_tensor, output);
0 commit comments