@@ -242,7 +242,8 @@ void set_argument(sycl::handler &cgh, int index, ordered_json &item) {
242242}
243243
244244static void sycl_kernel_launch (sycl::queue &stream, sycl::kernel &kernel_ptr,
245- KernelArguments triton_args) {
245+ KernelArguments triton_args,
246+ bool get_kernel_time) {
246247 std::string kernel_name =
247248 kernel_ptr.get_info <sycl::info::kernel::function_name>();
248249
@@ -290,7 +291,18 @@ static void sycl_kernel_launch(sycl::queue &stream, sycl::kernel &kernel_ptr,
290291 assert (narg == expected_num_params);
291292 cgh.parallel_for (parallel_work_size, kernel_ptr);
292293 };
293- stream.submit (cgf);
294+ if (get_kernel_time) {
295+ sycl::event event = stream.submit (cgf);
296+ event.wait ();
297+ uint64_t start =
298+ event.get_profiling_info <sycl::info::event_profiling::command_start>();
299+ uint64_t end =
300+ event.get_profiling_info <sycl::info::event_profiling::command_end>();
301+ double duration = static_cast <double >(end - start) / 1000000 ;
302+ std::cout << " Kernel execution time: " << duration << " ms" << std::endl;
303+ } else {
304+ stream.submit (cgf);
305+ }
294306 stream.wait_and_throw ();
295307}
296308
@@ -317,7 +329,7 @@ at::TensorOptions getTensorOptions(const std::string &dtype) {
317329}
318330
319331at::Tensor launchKernel (sycl::queue stream, sycl::kernel kernel,
320- KernelArguments triton_args) {
332+ KernelArguments triton_args, bool get_kernel_time ) {
321333
322334 auto tensor_ptr = [](const torch::Tensor &t) -> void * {
323335 return static_cast <void *>(t.data_ptr ());
@@ -352,8 +364,14 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
352364 }
353365 }
354366
367+ if (!triton_args.host_outbuffer .defined ()) {
368+ std::string message = " Output tensor isn't configurated; \
369+ the second positional parameter is " ;
370+ throw std::runtime_error (message + triton_args.out_tensor_name );
371+ }
372+
355373 // Launch SYCL kernel
356- sycl_kernel_launch (stream, kernel, triton_args);
374+ sycl_kernel_launch (stream, kernel, triton_args, get_kernel_time );
357375
358376 // copy back
359377 stream
@@ -372,9 +390,24 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
372390 return triton_args.host_outbuffer ;
373391}
374392
// Reports whether `option` appears among the optional command-line arguments.
// Only argv[2] .. argv[argc - 1] are examined; argv[0] and argv[1] are never
// matched. Optional arguments may be given in any order.
bool check_option_amoung_argv(int argc, char **argv, std::string option) {
  // Starting the scan at index 2 makes the loop a no-op whenever argc <= 2.
  for (int pos = 2; pos < argc; ++pos) {
    if (option == argv[pos]) {
      return true;
    }
  }
  return false;
}
406+
375407int main (int argc, char **argv) {
376408 try {
377409 std::string print_output_kernel_tensor = " --print-output-kernel-tensor" ;
410+ std::string enable_profiling = " --enable-profiling" ;
378411 if (argc < 2 ) {
379412 std::cout << " Help: " << std::endl;
380413 std::cout << " <Executable> <Output Tensor Name>" << std::endl;
@@ -383,11 +416,22 @@ int main(int argc, char **argv) {
383416 << std::endl;
384417 std::cout << " ./build/SPIRVRunner tensor_2 " << print_output_kernel_tensor
385418 << std::endl;
419+ std::cout << " To get kernel time, use:" << std::endl;
420+ std::cout << " ./build/SPIRVRunner tensor_2 " << enable_profiling
421+ << std::endl;
386422 throw std::runtime_error (" Input arguments are missing \n " );
387423 }
388424
389425 // initialize sycl runtime
390- sycl::queue q = sycl::queue (sycl::gpu_selector_v, exception_handler);
426+ bool get_kernel_time =
427+ check_option_amoung_argv (argc, argv, enable_profiling);
428+ sycl::queue q;
429+ if (get_kernel_time) {
430+ sycl::property_list prop_list{sycl::property::queue::enable_profiling ()};
431+ q = sycl::queue (sycl::gpu_selector_v, exception_handler, prop_list);
432+ } else {
433+ q = sycl::queue (sycl::gpu_selector_v, exception_handler);
434+ }
391435
392436 std::cout << " Running on device: "
393437 << q.get_device ().get_info <sycl::info::device::name>() << " \n " ;
@@ -409,9 +453,9 @@ int main(int argc, char **argv) {
409453 std::cout << " Loaded kernel with " << n_regs << " registers and "
410454 << n_spills << " register spills." << std::endl;
411455
412- auto output = launchKernel (q, kernel, tritonArgDict);
456+ auto output = launchKernel (q, kernel, tritonArgDict, get_kernel_time );
413457
414- if (argc == 3 && argv[ 2 ] == print_output_kernel_tensor) {
458+ if (check_option_amoung_argv ( argc, argv, print_output_kernel_tensor) ) {
415459 std::cout << " Kernel return output: " << output[0 ] << std::endl;
416460 }
417461
0 commit comments