Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def serialize_kernel_metadata(arg, args_dict):
args_dict["shared_memory"] = arg.shared
args_dict["kernel_name"] = arg.name
args_dict["spv_name"] = f"{arg.name}.spv"
args_dict["build_flags"] = arg.build_flags


def serialize_args(args, constants, signature):
Expand Down
1 change: 1 addition & 0 deletions third_party/intel/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ def serialize_kernel_metadata(arg, args_dict):
args_dict['shared_memory'] = arg.shared
args_dict['kernel_name'] = arg.name
args_dict['spv_name'] = f"{arg.name}.spv"
args_dict['build_flags'] = arg.build_flags


def serialize_args(args, constants, signature):
Expand Down
12 changes: 7 additions & 5 deletions utils/SPIRVRunner/SPIRVRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ struct KernelArguments {
int threads_per_warp;
int shared_memory;
std::string kernel_name;
std::string build_flags;
std::string spv_name;
ordered_json jsonData;
std::vector<char *> dev_buffers;
Expand Down Expand Up @@ -94,6 +95,7 @@ struct KernelArguments {
shared_memory = jsonData.at("shared_memory");
threads_per_warp = jsonData.at("threads_per_warp");
kernel_name = jsonData.at("kernel_name");
build_flags = jsonData.at("build_flags");
spv_name =
spirv_dump_dir + "/" + jsonData.at("spv_name").get<std::string>();
out_tensor_name = outtensorname;
Expand Down Expand Up @@ -123,8 +125,9 @@ static inline T checkSyclErrors(const std::tuple<T, ze_result_t> tuple) {
/** SYCL Functions **/
std::tuple<sycl::kernel_bundle<sycl::bundle_state::executable>, sycl::kernel,
int32_t, int32_t>
loadBinary(const std::string &kernel_name, uint8_t *binary_ptr,
const size_t binary_size, const size_t deviceId) {
loadBinary(const std::string &kernel_name, const std::string &build_flags,
uint8_t *binary_ptr, const size_t binary_size,
const size_t deviceId) {
int32_t n_regs = 0;
int32_t n_spills = 0;

Expand All @@ -140,9 +143,8 @@ loadBinary(const std::string &kernel_name, uint8_t *binary_ptr,
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
const auto l0_context =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
const char *build_flags = "";
auto l0_module = checkSyclErrors(create_module(
l0_context, l0_device, binary_ptr, binary_size, build_flags));
l0_context, l0_device, binary_ptr, binary_size, build_flags.c_str()));
auto l0_kernel = checkSyclErrors(create_function(l0_module, kernel_name));

ze_kernel_properties_t props;
Expand Down Expand Up @@ -395,7 +397,7 @@ int main(int argc, char **argv) {
std::cout << "Read " << spirv.size() << " byte kernel." << std::endl;

auto [kernel_bundle, kernel, n_regs, n_spills] =
loadBinary(tritonArgDict.kernel_name,
loadBinary(tritonArgDict.kernel_name, tritonArgDict.build_flags,
reinterpret_cast<uint8_t *>(spirv.data()), spirv.size(), 0);

// TODO: missing number of registers
Expand Down