@@ -47,6 +47,12 @@ auto read_spirv(const std::string &filename) {
   return read_file_as_bytes(filename);
 }
 
+// Host output tensor buffers and their indexes into dev_buffers
+struct TensorBuffer {
+  torch::Tensor buffer_ptr;
+  size_t index;
+};
+
 // Structure that contains Triton kernel arguments
 struct KernelArguments {
   int gridX;
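The new TensorBuffer struct pairs a host-side result tensor with the position of its device allocation in KernelArguments::dev_buffers. A minimal sketch of a populated entry (the shape and index below are hypothetical, not taken from this change):

    // Hypothetical example: the third device buffer holds a 1024-element
    // float32 output, staged into a zero-initialized host tensor.
    TensorBuffer tb;
    tb.buffer_ptr = torch::zeros({1024}, torch::dtype(torch::kFloat32));
    tb.index = 2; // offset of the matching pointer in dev_buffers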
@@ -62,11 +68,11 @@ struct KernelArguments {
   std::string spv_name;
   ordered_json jsonData;
   std::vector<char *> dev_buffers;
-  torch::Tensor host_outbuffer;
-  std::string out_tensor_name;
+  std::vector<TensorBuffer> host_outbuffers;
+  std::vector<std::string> out_tensor_names;
   std::string spirv_dump_dir;
 
-  KernelArguments(const std::string &outtensorname) {
+  KernelArguments(const std::vector<std::string> &outtensornames) {
     // Check if the triton_xpu_dump path exists; if not, point to the
     // current directory
     auto env_path = std::getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS");
@@ -97,7 +103,7 @@ struct KernelArguments {
     build_flags = jsonData.at("build_flags");
     spv_name =
        spirv_dump_dir + "/" + jsonData.at("spv_name").get<std::string>();
-    out_tensor_name = outtensorname;
+    out_tensor_names = outtensornames;
   }
 };
 
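With the constructor now taking a vector, a caller can request several outputs at once. A hedged usage sketch (the tensor names are placeholders; real names must match "name" entries in the dumped JSON argument_list):

    // Hypothetical call site requesting two output tensors.
    KernelArguments args({"out_ptr0", "out_ptr1"});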
@@ -348,13 +354,14 @@ at::TensorOptions getTensorOptions(const std::string &dtype) {
   }
 }
 
-at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
-                        KernelArguments triton_args, bool get_kernel_time) {
+std::vector<TensorBuffer> launchKernel(sycl::queue stream, sycl::kernel kernel,
+                                       KernelArguments triton_args,
+                                       bool get_kernel_time) {
 
   auto tensor_ptr = [](const torch::Tensor &t) -> void * {
     return static_cast<void *>(t.data_ptr());
   };
-  int devout_idx = 0;
+
   for (auto &item : triton_args.jsonData["argument_list"]) {
     if (item.contains("type")) {
       if (item.at("type").get<std::string>() == "tensor") {
@@ -369,36 +376,40 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
             .wait_and_throw();
 
         // Configure output tensor
-        if (item.at("name").get<std::string>() == triton_args.out_tensor_name) {
-          devout_idx = triton_args.dev_buffers.size() - 1;
-          triton_args.host_outbuffer = torch::zeros(
-              {tensor.sizes()}, getTensorOptions(item.at("dtype")));
-          std::cout << "Tensor output: " << triton_args.host_outbuffer.sizes()
-                    << ", " << triton_args.host_outbuffer.scalar_type() << " ("
-                    << triton_args.host_outbuffer.nbytes() << " bytes)"
-                    << std::endl;
+        if (std::find(triton_args.out_tensor_names.begin(),
+                      triton_args.out_tensor_names.end(),
+                      item.at("name").get<std::string>()) !=
+            triton_args.out_tensor_names.end()) {
+          TensorBuffer tb;
+          tb.buffer_ptr = torch::zeros({tensor.sizes()},
+                                       getTensorOptions(item.at("dtype")));
+          tb.index = triton_args.dev_buffers.size() - 1;
+          triton_args.host_outbuffers.push_back(tb);
+          std::cout
+              << "Tensor output[" << triton_args.host_outbuffers.back().index
+              << "]: " << triton_args.host_outbuffers.back().buffer_ptr.sizes()
+              << ", "
+              << triton_args.host_outbuffers.back().buffer_ptr.scalar_type()
+              << " (" << triton_args.host_outbuffers.back().buffer_ptr.nbytes()
+              << " bytes)" << std::endl;
         }
       }
     } else {
      throw std::runtime_error("Type entry is missing in JSON argument_list");
    }
  }
 
-  if (!triton_args.host_outbuffer.defined()) {
-    std::string message = "Output tensor isn't configured; \
-    the second positional parameter is ";
-    throw std::runtime_error(message + triton_args.out_tensor_name);
-  }
-
   // Launch SYCL kernel
   sycl_kernel_launch(stream, kernel, triton_args, get_kernel_time);
 
-  // copy back
-  stream
-      .memcpy(tensor_ptr(triton_args.host_outbuffer),
-              triton_args.dev_buffers.at(devout_idx),
-              triton_args.host_outbuffer.nbytes())
-      .wait_and_throw();
+  // Copy back the output tensors
+  for (const auto &item : triton_args.host_outbuffers) {
+    stream
+        .memcpy(tensor_ptr(item.buffer_ptr),
+                triton_args.dev_buffers.at(item.index),
+                item.buffer_ptr.nbytes())
+        .wait_and_throw();
+  }
 
   for (auto *dev_ptr : triton_args.dev_buffers) {
     if (dev_ptr)
@@ -407,7 +418,7 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
       throw std::runtime_error("sycl::free failed\n");
   }
 
-  return triton_args.host_outbuffer;
+  return triton_args.host_outbuffers;
 }
 
 int main(int argc, char **argv) {
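launchKernel now matches each tensor argument's name against out_tensor_names with a linear std::find (which requires <algorithm>). For the handful of outputs a Triton kernel typically has this is fine; if the list grew large, a set-based lookup would be a drop-in alternative. A sketch under that assumption, not part of this commit:

    // Sketch only: build the set once, then test membership in O(1) average
    // time instead of scanning the vector for every tensor argument.
    // Requires #include <unordered_set>.
    const std::unordered_set<std::string> out_names(
        triton_args.out_tensor_names.begin(),
        triton_args.out_tensor_names.end());
    if (out_names.count(item.at("name").get<std::string>()) != 0) {
      // ... configure the TensorBuffer as in the hunk above ...
    }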
@@ -430,7 +441,7 @@ int main(int argc, char **argv) {
   initDevices(&q);
 
   // Parse the JSON file and create argument dictionary
-  KernelArguments tritonArgDict(cliopts.output_tensor);
+  KernelArguments tritonArgDict(cliopts.output_tensors);
 
   // read spirv
   auto spirv = read_spirv(tritonArgDict.spv_name);
@@ -444,12 +455,16 @@ int main(int argc, char **argv) {
     std::cout << "Loaded kernel with " << n_regs << " registers and "
               << n_spills << " register spills." << std::endl;
 
-    auto output =
+    auto output_tensors =
         launchKernel(q, kernel, tritonArgDict, cliopts.get_kernel_time);
 
-    auto output_tensor = tritonArgDict.spirv_dump_dir + "/cpp_outs.pt";
-    write_tensor(output_tensor, output);
-    std::cout << "Output Tensor Path: " << output_tensor << std::endl;
+    // Write each output tensor to its own indexed file
+    for (auto &item : output_tensors) {
+      auto output_tensor = tritonArgDict.spirv_dump_dir + "/cpp_outs_" +
+                           std::to_string(item.index) + ".pt";
+      write_tensor(output_tensor, item.buffer_ptr);
+      std::cout << "Output Tensor Path: " << output_tensor << std::endl;
+    }
   } catch (const std::runtime_error &e) {
     std::cerr << "Error: " << e.what() << std::endl;
     return EXIT_FAILURE;
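The diff assumes cliopts.output_tensors now carries multiple names, but the CLI parsing itself is outside these hunks. One plausible way to produce such a vector (a hypothetical helper, not shown in this commit) is to split a comma-separated flag value:

    #include <sstream>
    #include <string>
    #include <vector>

    // Hypothetical helper: turns "out0,out1" into {"out0", "out1"}.
    std::vector<std::string> split_names(const std::string &csv) {
      std::vector<std::string> names;
      std::stringstream ss(csv);
      std::string name;
      while (std::getline(ss, name, ','))
        if (!name.empty())
          names.push_back(name);
      return names;
    }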