diff --git a/utils/SPIRVRunner/SPIRVRunner.cpp b/utils/SPIRVRunner/SPIRVRunner.cpp index bdd0157946..f31ffb74c8 100644 --- a/utils/SPIRVRunner/SPIRVRunner.cpp +++ b/utils/SPIRVRunner/SPIRVRunner.cpp @@ -374,10 +374,15 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel, int main(int argc, char **argv) { try { + std::string print_output_kernel_tensor = "--print-output-kernel-tensor"; if (argc < 2) { std::cout << "Help: " << std::endl; - std::cout << " \n"; + std::cout << " " << std::endl; std::cout << "./build/SPIRVRunner tensor_2" << std::endl; + std::cout << "To print the output kernel tensor to stdout, use:" + << std::endl; + std::cout << "./build/SPIRVRunner tensor_2 " << print_output_kernel_tensor + << std::endl; throw std::runtime_error("Input arguments are missing \n"); } @@ -405,7 +410,10 @@ int main(int argc, char **argv) { << n_spills << " register spills." << std::endl; auto output = launchKernel(q, kernel, tritonArgDict); - std::cout << "Kernel return output: " << output[0] << std::endl; + + if (argc == 3 && argv[2] == print_output_kernel_tensor) { + std::cout << "Kernel return output: " << output[0] << std::endl; + } auto output_tensor = tritonArgDict.spirv_dump_dir + "/cpp_outs.pt"; write_tensor(output_tensor, output);