Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 49 additions & 14 deletions utils/SPIRVRunner/SPIRVRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ void set_argument(sycl::handler &cgh, int index, ordered_json &item) {
}

static void sycl_kernel_launch(sycl::queue &stream, sycl::kernel &kernel_ptr,
KernelArguments triton_args) {
KernelArguments triton_args,
bool get_kernel_time) {
std::string kernel_name =
kernel_ptr.get_info<sycl::info::kernel::function_name>();

Expand Down Expand Up @@ -290,7 +291,18 @@ static void sycl_kernel_launch(sycl::queue &stream, sycl::kernel &kernel_ptr,
assert(narg == expected_num_params);
cgh.parallel_for(parallel_work_size, kernel_ptr);
};
stream.submit(cgf);
if (get_kernel_time) {
sycl::event event = stream.submit(cgf);
event.wait();
uint64_t start =
event.get_profiling_info<sycl::info::event_profiling::command_start>();
uint64_t end =
event.get_profiling_info<sycl::info::event_profiling::command_end>();
double duration = static_cast<double>(end - start) / 1000000;
std::cout << "Kernel execution time: " << duration << " ms" << std::endl;
} else {
stream.submit(cgf);
}
stream.wait_and_throw();
}

Expand All @@ -317,7 +329,7 @@ at::TensorOptions getTensorOptions(const std::string &dtype) {
}

at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
KernelArguments triton_args) {
KernelArguments triton_args, bool get_kernel_time) {

auto tensor_ptr = [](const torch::Tensor &t) -> void * {
return static_cast<void *>(t.data_ptr());
Expand Down Expand Up @@ -352,8 +364,14 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
}
}

if (!triton_args.host_outbuffer.defined()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check was added so that a missing output tensor produces a clear error message instead of an incomprehensible one later on.

std::string message = "Output tensor isn't configured; \
the second positional parameter is ";
throw std::runtime_error(message + triton_args.out_tensor_name);
}

// Launch SYCL kernel
sycl_kernel_launch(stream, kernel, triton_args);
sycl_kernel_launch(stream, kernel, triton_args, get_kernel_time);

// copy back
stream
Expand All @@ -372,22 +390,43 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
return triton_args.host_outbuffer;
}

/// Reports whether `option` appears among the optional command-line arguments.
///
/// Only argv[2] .. argv[argc - 1] are examined; argv[0] and argv[1] are never
/// compared, so the optional flags may be supplied in any order after them.
///
/// @param argc   Number of entries in argv.
/// @param argv   Argument vector.
/// @param option Exact flag text to look for.
/// @return true if some argv[i] with i >= 2 equals `option`, false otherwise.
bool check_option_amoung_argv(int argc, char **argv, std::string option) {
  // Scanning starts at index 2; when argc <= 2 the loop body never executes.
  int idx = 2;
  while (idx < argc) {
    if (option == argv[idx]) {
      return true;
    }
    ++idx;
  }
  return false;
}

int main(int argc, char **argv) {
try {
std::string print_output_kernel_tensor = "--print-output-kernel-tensor";
std::string enable_profiling = "--enable-profiling";
if (argc < 2) {
std::cout << "Help: " << std::endl;
std::cout << "<Executable> <Output Tensor Name>" << std::endl;
std::cout << "./build/SPIRVRunner tensor_2" << std::endl;
std::cout << "To print the output kernel tensor to stdout, use:"
<< std::endl;
std::cout << "./build/SPIRVRunner tensor_2 " << print_output_kernel_tensor
std::cout << "To get kernel time, use:" << std::endl;
std::cout << "./build/SPIRVRunner tensor_2 " << enable_profiling
<< std::endl;
throw std::runtime_error("Input arguments are missing \n");
}

// initialize sycl runtime
sycl::queue q = sycl::queue(sycl::gpu_selector_v, exception_handler);
bool get_kernel_time =
check_option_amoung_argv(argc, argv, enable_profiling);
sycl::queue q;
if (get_kernel_time) {
sycl::property_list prop_list{sycl::property::queue::enable_profiling()};
q = sycl::queue(sycl::gpu_selector_v, exception_handler, prop_list);
} else {
q = sycl::queue(sycl::gpu_selector_v, exception_handler);
}

std::cout << "Running on device: "
<< q.get_device().get_info<sycl::info::device::name>() << "\n";
Expand All @@ -409,11 +448,7 @@ int main(int argc, char **argv) {
std::cout << "Loaded kernel with " << n_regs << " registers and "
<< n_spills << " register spills." << std::endl;

auto output = launchKernel(q, kernel, tritonArgDict);

if (argc == 3 && argv[2] == print_output_kernel_tensor) {
std::cout << "Kernel return output: " << output[0] << std::endl;
}
auto output = launchKernel(q, kernel, tritonArgDict, get_kernel_time);

auto output_tensor = tritonArgDict.spirv_dump_dir + "/cpp_outs.pt";
write_tensor(output_tensor, output);
Expand Down