@@ -59,6 +59,7 @@ struct KernelArguments {
5959 int threads_per_warp;
6060 int shared_memory;
6161 std::string kernel_name;
62+ std::string build_flags;
6263 std::string spv_name;
6364 ordered_json jsonData;
6465 std::vector<char *> dev_buffers;
@@ -94,6 +95,7 @@ struct KernelArguments {
9495 shared_memory = jsonData.at (" shared_memory" );
9596 threads_per_warp = jsonData.at (" threads_per_warp" );
9697 kernel_name = jsonData.at (" kernel_name" );
98+ build_flags = jsonData.at (" build_flags" );
9799 spv_name =
98100 spirv_dump_dir + " /" + jsonData.at (" spv_name" ).get <std::string>();
99101 out_tensor_name = outtensorname;
@@ -123,8 +125,9 @@ static inline T checkSyclErrors(const std::tuple<T, ze_result_t> tuple) {
123125/* * SYCL Functions **/
124126std::tuple<sycl::kernel_bundle<sycl::bundle_state::executable>, sycl::kernel,
125127 int32_t , int32_t >
126- loadBinary (const std::string &kernel_name, uint8_t *binary_ptr,
127- const size_t binary_size, const size_t deviceId) {
128+ loadBinary (const std::string &kernel_name, const std::string &build_flags,
129+ uint8_t *binary_ptr, const size_t binary_size,
130+ const size_t deviceId) {
128131 int32_t n_regs = 0 ;
129132 int32_t n_spills = 0 ;
130133
@@ -140,9 +143,8 @@ loadBinary(const std::string &kernel_name, uint8_t *binary_ptr,
140143 sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
141144 const auto l0_context =
142145 sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
143- const char *build_flags = " " ;
144146 auto l0_module = checkSyclErrors (create_module (
145- l0_context, l0_device, binary_ptr, binary_size, build_flags));
147+ l0_context, l0_device, binary_ptr, binary_size, build_flags. c_str () ));
146148 auto l0_kernel = checkSyclErrors (create_function (l0_module, kernel_name));
147149
148150 ze_kernel_properties_t props;
@@ -395,7 +397,7 @@ int main(int argc, char **argv) {
395397 std::cout << " Read " << spirv.size () << " byte kernel." << std::endl;
396398
397399 auto [kernel_bundle, kernel, n_regs, n_spills] =
398- loadBinary (tritonArgDict.kernel_name ,
400+ loadBinary (tritonArgDict.kernel_name , tritonArgDict. build_flags ,
399401 reinterpret_cast <uint8_t *>(spirv.data ()), spirv.size (), 0 );
400402
401403 // TODO: missing number of registers
0 commit comments