Skip to content

Commit b9e4be1

Browse files
authored
Dump also build_flags for SPIRVrunner (#2554)
Closes #2552 Signed-off-by: Anatoly Myachev <[email protected]>
1 parent a5140b7 commit b9e4be1

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def serialize_kernel_metadata(arg, args_dict):
405405
args_dict["shared_memory"] = arg.shared
406406
args_dict["kernel_name"] = arg.name
407407
args_dict["spv_name"] = f"{arg.name}.spv"
408+
args_dict["build_flags"] = arg.build_flags
408409

409410

410411
def serialize_args(args, constants, signature):

third_party/intel/backend/driver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ def serialize_kernel_metadata(arg, args_dict):
441441
args_dict['shared_memory'] = arg.shared
442442
args_dict['kernel_name'] = arg.name
443443
args_dict['spv_name'] = f"{arg.name}.spv"
444+
args_dict['build_flags'] = arg.build_flags
444445

445446

446447
def serialize_args(args, constants, signature):

utils/SPIRVRunner/SPIRVRunner.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct KernelArguments {
5959
int threads_per_warp;
6060
int shared_memory;
6161
std::string kernel_name;
62+
std::string build_flags;
6263
std::string spv_name;
6364
ordered_json jsonData;
6465
std::vector<char *> dev_buffers;
@@ -94,6 +95,7 @@ struct KernelArguments {
9495
shared_memory = jsonData.at("shared_memory");
9596
threads_per_warp = jsonData.at("threads_per_warp");
9697
kernel_name = jsonData.at("kernel_name");
98+
build_flags = jsonData.at("build_flags");
9799
spv_name =
98100
spirv_dump_dir + "/" + jsonData.at("spv_name").get<std::string>();
99101
out_tensor_name = outtensorname;
@@ -123,8 +125,9 @@ static inline T checkSyclErrors(const std::tuple<T, ze_result_t> tuple) {
123125
/** SYCL Functions **/
124126
std::tuple<sycl::kernel_bundle<sycl::bundle_state::executable>, sycl::kernel,
125127
int32_t, int32_t>
126-
loadBinary(const std::string &kernel_name, uint8_t *binary_ptr,
127-
const size_t binary_size, const size_t deviceId) {
128+
loadBinary(const std::string &kernel_name, const std::string &build_flags,
129+
uint8_t *binary_ptr, const size_t binary_size,
130+
const size_t deviceId) {
128131
int32_t n_regs = 0;
129132
int32_t n_spills = 0;
130133

@@ -140,9 +143,8 @@ loadBinary(const std::string &kernel_name, uint8_t *binary_ptr,
140143
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
141144
const auto l0_context =
142145
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
143-
const char *build_flags = "";
144146
auto l0_module = checkSyclErrors(create_module(
145-
l0_context, l0_device, binary_ptr, binary_size, build_flags));
147+
l0_context, l0_device, binary_ptr, binary_size, build_flags.c_str()));
146148
auto l0_kernel = checkSyclErrors(create_function(l0_module, kernel_name));
147149

148150
ze_kernel_properties_t props;
@@ -395,7 +397,7 @@ int main(int argc, char **argv) {
395397
std::cout << "Read " << spirv.size() << " byte kernel." << std::endl;
396398

397399
auto [kernel_bundle, kernel, n_regs, n_spills] =
398-
loadBinary(tritonArgDict.kernel_name,
400+
loadBinary(tritonArgDict.kernel_name, tritonArgDict.build_flags,
399401
reinterpret_cast<uint8_t *>(spirv.data()), spirv.size(), 0);
400402

401403
// TODO: missing number of registers

0 commit comments

Comments
 (0)