-
Couldn't load subscription status.
- Fork 75
Generate native code using L0 sdk instead of ocloc
#5342
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -323,6 +323,124 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) { | |||||
| } | ||||||
| } | ||||||
|
|
||||||
| extern "C" EXPORT_FUNC PyObject *get_native_code(PyObject *args) { | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Most of this function is taken from the implementation of |
||||||
| const char *name, *build_flags_ptr; | ||||||
| int shared; | ||||||
| PyObject *py_bytes; | ||||||
| int devId; | ||||||
|
|
||||||
| if (!PyArg_ParseTuple(args, "sSisi", &name, &py_bytes, &shared, | ||||||
| &build_flags_ptr, &devId)) { | ||||||
| std::cerr << "loadBinary arg parse failed" << std::endl; | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| return NULL; | ||||||
| } | ||||||
|
|
||||||
| if (devId > g_sycl_l0_device_list.size()) { | ||||||
| std::cerr << "Device is not found " << std::endl; | ||||||
| return NULL; | ||||||
| } | ||||||
|
|
||||||
| BuildFlags build_flags(build_flags_ptr); | ||||||
|
|
||||||
| try { | ||||||
| const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId]; | ||||||
| const sycl::device sycl_device = sycl_l0_device_pair.first; | ||||||
| const auto l0_device = sycl_l0_device_pair.second; | ||||||
|
|
||||||
| const std::string kernel_name = name; | ||||||
| const size_t binary_size = PyBytes_Size(py_bytes); | ||||||
|
|
||||||
| uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes); | ||||||
| const auto &ctx = get_default_context(sycl_device); | ||||||
| const auto l0_context = | ||||||
| sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx); | ||||||
|
|
||||||
| ze_device_compute_properties_t compute_properties = {}; | ||||||
| compute_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_COMPUTE_PROPERTIES; | ||||||
| zeDeviceGetComputeProperties(l0_device, &compute_properties); | ||||||
| int32_t n_max_threads = compute_properties.maxTotalGroupSize; | ||||||
| auto [l0_module, l0_kernel, n_spills] = | ||||||
| compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device, | ||||||
| l0_context, build_flags(), true); | ||||||
| const bool debugEnabled = getBoolEnv("TRITON_DEBUG"); | ||||||
|
|
||||||
| constexpr int32_t max_reg_spill = 1000; | ||||||
| const bool is_GRF_mode_specified = build_flags.hasGRFSizeFlag(); | ||||||
|
|
||||||
| // If the register mode isn't set, and the number of spills is greater | ||||||
| // than the threshold, recompile the kernel using large GRF mode. | ||||||
| if (!is_GRF_mode_specified && n_spills > max_reg_spill) { | ||||||
| if (debugEnabled) | ||||||
| std::cout << "(I): Detected " << n_spills | ||||||
| << " spills, recompiling the kernel using large GRF mode" | ||||||
| << std::endl; | ||||||
|
|
||||||
| build_flags.addLargeGRFSizeFlag(); | ||||||
|
|
||||||
| try { | ||||||
| auto [l0_module_dgrf, l0_kernel_dgrf, n_spills_dgrf] = | ||||||
| compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, | ||||||
| l0_device, l0_context, build_flags(), true); | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| if (debugEnabled) | ||||||
| std::cout << "(I): Kernel has now " << n_spills_dgrf << " spills" | ||||||
| << std::endl; | ||||||
|
|
||||||
| std::swap(l0_module, l0_module_dgrf); | ||||||
| std::swap(l0_kernel, l0_kernel_dgrf); | ||||||
| std::swap(n_spills, n_spills_dgrf); | ||||||
|
|
||||||
| // clean up the unused module and kernel. | ||||||
| auto error_no = zeKernelDestroy(l0_kernel_dgrf); | ||||||
| if (error_no != ZE_RESULT_SUCCESS) { | ||||||
| std::cerr | ||||||
| << "[Ignoring] Intel - Error during destroy unused L0 kernel" | ||||||
| << std::endl; | ||||||
| } | ||||||
| error_no = zeModuleDestroy(l0_module_dgrf); | ||||||
| if (error_no != ZE_RESULT_SUCCESS) { | ||||||
| std::cerr | ||||||
| << "[Ignoring] Intel - Error during destroy unused L0 module" | ||||||
| << std::endl; | ||||||
| } | ||||||
| } catch (const std::exception &e) { | ||||||
| std::cerr << "[Ignoring] Error during Intel loadBinary with large " | ||||||
| "registers: " | ||||||
| << e.what() << std::endl; | ||||||
| // construct previous working version | ||||||
| build_flags = BuildFlags(build_flags_ptr); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| if (debugEnabled && n_spills) { | ||||||
| std::cout << "(I): Detected " << n_spills << " spills for \"" | ||||||
| << kernel_name << "\"" << std::endl; | ||||||
| } | ||||||
|
|
||||||
| size_t szBinary = 0; | ||||||
| zeModuleGetNativeBinary(l0_module, &szBinary, nullptr); | ||||||
| std::vector<uint8_t> pBinary(szBinary); | ||||||
| zeModuleGetNativeBinary(l0_module, &szBinary, pBinary.data()); | ||||||
|
|
||||||
| PyObject *pyBytes = PyBytes_FromStringAndSize( | ||||||
| reinterpret_cast<const char *>(pBinary.data()), | ||||||
| static_cast<Py_ssize_t>(szBinary)); | ||||||
| if (!pyBytes) | ||||||
| return NULL; | ||||||
| return Py_BuildValue("(Niii)", pyBytes, build_flags.n_regs(), n_spills, | ||||||
| n_max_threads); | ||||||
|
|
||||||
| } catch (const std::exception &e) { | ||||||
| PyGILState_STATE gil_state; | ||||||
| gil_state = PyGILState_Ensure(); | ||||||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | ||||||
| std::cerr << "Error during Intel get_native_code: " << e.what() | ||||||
| << std::endl; | ||||||
| PyGILState_Release(gil_state); | ||||||
| return NULL; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| extern "C" EXPORT_FUNC PyObject *init_devices(PyObject *cap) { | ||||||
| void *queue = NULL; | ||||||
| if (!(queue = PyLong_AsVoidPtr(cap))) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -191,7 +191,7 @@ class SpirvUtils: | |||
|
|
||||
| def __init__(self, cache_path: str): | ||||
| self.shared_library = ctypes.PyDLL(cache_path) | ||||
| methods = ("init_devices", "load_binary", "wait_on_sycl_queue", "has_opencl_extension") | ||||
| methods = ("init_devices", "load_binary", "get_native_code", "wait_on_sycl_queue", "has_opencl_extension") | ||||
| for method in methods: | ||||
| getattr(self.shared_library, method).restype = ctypes.py_object | ||||
| getattr(self.shared_library, method).argtypes = (ctypes.py_object, ) | ||||
|
|
@@ -214,6 +214,9 @@ def load_binary(self, *args): | |||
| # driver.active.utils.load_binary((self.name, self.kernel, self.metadata.shared, self.metadata.build_flags, device)) | ||||
| return self.shared_library.load_binary(args) | ||||
|
|
||||
| def get_native_code(self, *args): | ||||
| return self.shared_library.get_native_code(args) | ||||
|
|
||||
| if os.name != 'nt': | ||||
|
|
||||
| def __del__(self): | ||||
|
|
@@ -314,8 +317,10 @@ def __init__(self): | |||
| # and can cause `Fatal Python error: Segmentation fault` | ||||
| mod = compile_module_from_src(src=Path(os.path.join(dirname, "driver.c")).read_text(), name="spirv_utils") | ||||
| self.load_binary = mod.load_binary | ||||
| self.get_native_code = mod.get_native_code | ||||
| self.get_device_properties = mod.get_device_properties | ||||
| self.device_count = mod.init_devices(self.get_sycl_queue()) | ||||
| # breakpoint() | ||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| self.wait_on_sycl_queue = mod.wait_on_sycl_queue | ||||
| self.has_opencl_extension = mod.has_opencl_extension | ||||
|
|
||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`triton.compile` may not be targeting the current device. The compilation flow should not depend on the runtime; it can also be used for cross-compilation and AOT compilation.