Skip to content

Commit 1d6482e

Browse files
committed
Generate native code using L0 sdk instead of 'ocloc'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 190d245 commit 1d6482e

File tree

3 files changed

+129
-40
lines changed

3 files changed

+129
-40
lines changed

third_party/intel/backend/compiler.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
from typing import Any, Dict, Tuple
1010
from types import ModuleType
1111
import hashlib
12-
import tempfile
13-
import signal
1412
import os
15-
import subprocess
1613
from pathlib import Path
1714

1815

@@ -395,42 +392,11 @@ def make_spv(src, metadata, options, device_arch):
395392
metadata["generate_native_code"] = options.generate_native_code
396393

397394
if options.generate_native_code:
398-
with track("generate_native_code"), tempfile.TemporaryDirectory() as temp_dir:
399-
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
400-
fsrc.write(spirv)
401-
fbin = fsrc.name + '.o'
402-
403-
ocloc_cmd = [
404-
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch,
405-
'-options', metadata["build_flags"] + shader_dump_opt
406-
]
407-
408-
try:
409-
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
410-
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
411-
"""
412-
The exact message is something like:
413-
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
414-
is "spilled" enough for now?
415-
"""
416-
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
417-
# re-run with new build flags
418-
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
419-
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
420-
except subprocess.CalledProcessError as e:
421-
if e.returncode == 255:
422-
error = 'Internal Triton ZEBIN codegen error'
423-
elif e.returncode == 128 + signal.SIGSEGV:
424-
error = '`ocloc` raised SIGSEGV'
425-
else:
426-
error = f'`ocloc` failed with error code {e.returncode}'
427-
428-
raise RuntimeError(f'{error}\n'
429-
f'`ocloc` stderr:\n{e.output}\n'
430-
f'Repro command: {ocloc_cmd}\n') from e
431-
432-
with open(fbin, 'rb') as f:
433-
zebin = f.read()
395+
from triton.runtime.driver import driver
396+
# at this stage the driver is already initialized
397+
device = driver.active.utils.get_current_device()
398+
zebin, n_regs, n_spills, n_max_threads = driver.active.utils.get_native_code(
399+
metadata['name'], spirv, metadata['shared'], metadata['build_flags'] + shader_dump_opt, device)
434400
return zebin
435401
return spirv
436402

third_party/intel/backend/driver.c

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,124 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
323323
}
324324
}
325325

326+
extern "C" EXPORT_FUNC PyObject *get_native_code(PyObject *args) {
327+
const char *name, *build_flags_ptr;
328+
int shared;
329+
PyObject *py_bytes;
330+
int devId;
331+
332+
if (!PyArg_ParseTuple(args, "sSisi", &name, &py_bytes, &shared,
333+
&build_flags_ptr, &devId)) {
334+
std::cerr << "loadBinary arg parse failed" << std::endl;
335+
return NULL;
336+
}
337+
338+
if (devId > g_sycl_l0_device_list.size()) {
339+
std::cerr << "Device is not found " << std::endl;
340+
return NULL;
341+
}
342+
343+
BuildFlags build_flags(build_flags_ptr);
344+
345+
try {
346+
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId];
347+
const sycl::device sycl_device = sycl_l0_device_pair.first;
348+
const auto l0_device = sycl_l0_device_pair.second;
349+
350+
const std::string kernel_name = name;
351+
const size_t binary_size = PyBytes_Size(py_bytes);
352+
353+
uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
354+
const auto &ctx = get_default_context(sycl_device);
355+
const auto l0_context =
356+
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
357+
358+
ze_device_compute_properties_t compute_properties = {};
359+
compute_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_COMPUTE_PROPERTIES;
360+
zeDeviceGetComputeProperties(l0_device, &compute_properties);
361+
int32_t n_max_threads = compute_properties.maxTotalGroupSize;
362+
auto [l0_module, l0_kernel, n_spills] =
363+
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
364+
l0_context, build_flags(), true);
365+
const bool debugEnabled = getBoolEnv("TRITON_DEBUG");
366+
367+
constexpr int32_t max_reg_spill = 1000;
368+
const bool is_GRF_mode_specified = build_flags.hasGRFSizeFlag();
369+
370+
// If the register mode isn't set, and the number of spills is greater
371+
// than the threshold, recompile the kernel using large GRF mode.
372+
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
373+
if (debugEnabled)
374+
std::cout << "(I): Detected " << n_spills
375+
<< " spills, recompiling the kernel using large GRF mode"
376+
<< std::endl;
377+
378+
build_flags.addLargeGRFSizeFlag();
379+
380+
try {
381+
auto [l0_module_dgrf, l0_kernel_dgrf, n_spills_dgrf] =
382+
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name,
383+
l0_device, l0_context, build_flags(), true);
384+
385+
if (debugEnabled)
386+
std::cout << "(I): Kernel has now " << n_spills_dgrf << " spills"
387+
<< std::endl;
388+
389+
std::swap(l0_module, l0_module_dgrf);
390+
std::swap(l0_kernel, l0_kernel_dgrf);
391+
std::swap(n_spills, n_spills_dgrf);
392+
393+
// clean up the unused module and kernel.
394+
auto error_no = zeKernelDestroy(l0_kernel_dgrf);
395+
if (error_no != ZE_RESULT_SUCCESS) {
396+
std::cerr
397+
<< "[Ignoring] Intel - Error during destroy unused L0 kernel"
398+
<< std::endl;
399+
}
400+
error_no = zeModuleDestroy(l0_module_dgrf);
401+
if (error_no != ZE_RESULT_SUCCESS) {
402+
std::cerr
403+
<< "[Ignoring] Intel - Error during destroy unused L0 module"
404+
<< std::endl;
405+
}
406+
} catch (const std::exception &e) {
407+
std::cerr << "[Ignoring] Error during Intel loadBinary with large "
408+
"registers: "
409+
<< e.what() << std::endl;
410+
// construct previous working version
411+
build_flags = BuildFlags(build_flags_ptr);
412+
}
413+
}
414+
415+
if (debugEnabled && n_spills) {
416+
std::cout << "(I): Detected " << n_spills << " spills for \""
417+
<< kernel_name << "\"" << std::endl;
418+
}
419+
420+
size_t szBinary = 0;
421+
zeModuleGetNativeBinary(l0_module, &szBinary, nullptr);
422+
std::vector<uint8_t> pBinary(szBinary);
423+
zeModuleGetNativeBinary(l0_module, &szBinary, pBinary.data());
424+
425+
PyObject *pyBytes = PyBytes_FromStringAndSize(
426+
reinterpret_cast<const char *>(pBinary.data()),
427+
static_cast<Py_ssize_t>(szBinary));
428+
if (!pyBytes)
429+
return NULL;
430+
return Py_BuildValue("(Niii)", pyBytes, build_flags.n_regs(), n_spills,
431+
n_max_threads);
432+
433+
} catch (const std::exception &e) {
434+
PyGILState_STATE gil_state;
435+
gil_state = PyGILState_Ensure();
436+
PyErr_SetString(PyExc_RuntimeError, e.what());
437+
std::cerr << "Error during Intel get_native_code: " << e.what()
438+
<< std::endl;
439+
PyGILState_Release(gil_state);
440+
return NULL;
441+
}
442+
}
443+
326444
extern "C" EXPORT_FUNC PyObject *init_devices(PyObject *cap) {
327445
void *queue = NULL;
328446
if (!(queue = PyLong_AsVoidPtr(cap)))

third_party/intel/backend/driver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class SpirvUtils:
191191

192192
def __init__(self, cache_path: str):
193193
self.shared_library = ctypes.PyDLL(cache_path)
194-
methods = ("init_devices", "load_binary", "wait_on_sycl_queue", "has_opencl_extension")
194+
methods = ("init_devices", "load_binary", "get_native_code", "wait_on_sycl_queue", "has_opencl_extension")
195195
for method in methods:
196196
getattr(self.shared_library, method).restype = ctypes.py_object
197197
getattr(self.shared_library, method).argtypes = (ctypes.py_object, )
@@ -214,6 +214,9 @@ def load_binary(self, *args):
214214
# driver.active.utils.load_binary((self.name, self.kernel, self.metadata.shared, self.metadata.build_flags, device))
215215
return self.shared_library.load_binary(args)
216216

217+
def get_native_code(self, *args):
218+
return self.shared_library.get_native_code(args)
219+
217220
if os.name != 'nt':
218221

219222
def __del__(self):
@@ -314,8 +317,10 @@ def __init__(self):
314317
# and can cause `Fatal Python error: Segmentation fault`
315318
mod = compile_module_from_src(src=Path(os.path.join(dirname, "driver.c")).read_text(), name="spirv_utils")
316319
self.load_binary = mod.load_binary
320+
self.get_native_code = mod.get_native_code
317321
self.get_device_properties = mod.get_device_properties
318322
self.device_count = mod.init_devices(self.get_sycl_queue())
323+
# breakpoint()
319324
self.wait_on_sycl_queue = mod.wait_on_sycl_queue
320325
self.has_opencl_extension = mod.has_opencl_extension
321326

0 commit comments

Comments
 (0)