From 520525a1933ddeecd5d359902a765817c5e07b29 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Thu, 24 Oct 2024 11:41:19 +0000 Subject: [PATCH] Dump also 'build_flags' for SPIRVrunner Signed-off-by: Anatoly Myachev --- .../triton_kernels_benchmark/benchmark_driver.py | 1 + third_party/intel/backend/driver.py | 1 + utils/SPIRVRunner/SPIRVRunner.cpp | 12 +++++++----- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/benchmark_driver.py b/benchmarks/triton_kernels_benchmark/benchmark_driver.py index 470c6c19e5..2e0a484f06 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_driver.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_driver.py @@ -405,6 +405,7 @@ def serialize_kernel_metadata(arg, args_dict): args_dict["shared_memory"] = arg.shared args_dict["kernel_name"] = arg.name args_dict["spv_name"] = f"{arg.name}.spv" + args_dict["build_flags"] = arg.build_flags def serialize_args(args, constants, signature): diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 92aeb1f44d..485e1dcb91 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -441,6 +441,7 @@ def serialize_kernel_metadata(arg, args_dict): args_dict['shared_memory'] = arg.shared args_dict['kernel_name'] = arg.name args_dict['spv_name'] = f"{arg.name}.spv" + args_dict['build_flags'] = arg.build_flags def serialize_args(args, constants, signature): diff --git a/utils/SPIRVRunner/SPIRVRunner.cpp b/utils/SPIRVRunner/SPIRVRunner.cpp index 651632535e..bdd0157946 100644 --- a/utils/SPIRVRunner/SPIRVRunner.cpp +++ b/utils/SPIRVRunner/SPIRVRunner.cpp @@ -59,6 +59,7 @@ struct KernelArguments { int threads_per_warp; int shared_memory; std::string kernel_name; + std::string build_flags; std::string spv_name; ordered_json jsonData; std::vector dev_buffers; @@ -94,6 +95,7 @@ struct KernelArguments { shared_memory = jsonData.at("shared_memory"); threads_per_warp = jsonData.at("threads_per_warp"); kernel_name = jsonData.at("kernel_name"); + build_flags = jsonData.at("build_flags"); spv_name = spirv_dump_dir + "/" + jsonData.at("spv_name").get(); out_tensor_name = outtensorname; @@ -123,8 +125,9 @@ static inline T checkSyclErrors(const std::tuple tuple) { /** SYCL Functions **/ std::tuple, sycl::kernel, int32_t, int32_t> -loadBinary(const std::string &kernel_name, uint8_t *binary_ptr, - const size_t binary_size, const size_t deviceId) { +loadBinary(const std::string &kernel_name, const std::string &build_flags, + uint8_t *binary_ptr, const size_t binary_size, + const size_t deviceId) { int32_t n_regs = 0; int32_t n_spills = 0; @@ -140,9 +143,8 @@ loadBinary(const std::string &kernel_name, uint8_t *binary_ptr, sycl::get_native(sycl_device); const auto l0_context = sycl::get_native(ctx); - const char *build_flags = ""; auto l0_module = checkSyclErrors(create_module( - l0_context, l0_device, binary_ptr, binary_size, build_flags)); + l0_context, l0_device, binary_ptr, binary_size, build_flags.c_str())); auto l0_kernel = checkSyclErrors(create_function(l0_module, kernel_name)); ze_kernel_properties_t props; @@ -395,7 +397,7 @@ int main(int argc, char **argv) { std::cout << "Read " << spirv.size() << " byte kernel." << std::endl; auto [kernel_bundle, kernel, n_regs, n_spills] = - loadBinary(tritonArgDict.kernel_name, + loadBinary(tritonArgDict.kernel_name, tritonArgDict.build_flags, reinterpret_cast(spirv.data()), spirv.size(), 0); // TODO: missing number of registers