Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 5 additions & 39 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
from typing import Any, Dict, Tuple
from types import ModuleType
import hashlib
import tempfile
import signal
import os
import subprocess
from pathlib import Path


Expand Down Expand Up @@ -395,42 +392,11 @@ def make_spv(src, metadata, options, device_arch):
metadata["generate_native_code"] = options.generate_native_code

if options.generate_native_code:
with track("generate_native_code"), tempfile.TemporaryDirectory() as temp_dir:
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
fsrc.write(spirv)
fbin = fsrc.name + '.o'

ocloc_cmd = [
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch,
'-options', metadata["build_flags"] + shader_dump_opt
]

try:
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
"""
The exact message is something like:
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
is "spilled" enough for now?
"""
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
# re-run with new build flags
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
except subprocess.CalledProcessError as e:
if e.returncode == 255:
error = 'Internal Triton ZEBIN codegen error'
elif e.returncode == 128 + signal.SIGSEGV:
error = '`ocloc` raised SIGSEGV'
else:
error = f'`ocloc` failed with error code {e.returncode}'

raise RuntimeError(f'{error}\n'
f'`ocloc` stderr:\n{e.output}\n'
f'Repro command: {ocloc_cmd}\n') from e

with open(fbin, 'rb') as f:
zebin = f.read()
from triton.runtime.driver import driver
# at this stage the driver is already initialized
device = driver.active.utils.get_current_device()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The triton.compile maybe not target to the current device. The compilation flow should not depends on the runtime. It can be used as cross compiling and AOT compiling.

zebin, n_regs, n_spills, n_max_threads = driver.active.utils.get_native_code(
metadata['name'], spirv, metadata['shared'], metadata['build_flags'] + shader_dump_opt, device)
return zebin
return spirv

Expand Down
118 changes: 118 additions & 0 deletions third_party/intel/backend/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,124 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
}
}

extern "C" EXPORT_FUNC PyObject *get_native_code(PyObject *args) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of this function is taken from the implementation of load_binary function, as I've decided to leave it alone for now to avoid introducing regressions. However, there's potential to reduce code duplication there.

const char *name, *build_flags_ptr;
int shared;
PyObject *py_bytes;
int devId;

if (!PyArg_ParseTuple(args, "sSisi", &name, &py_bytes, &shared,
&build_flags_ptr, &devId)) {
std::cerr << "loadBinary arg parse failed" << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::cerr << "loadBinary arg parse failed" << std::endl;
std::cerr << "get_native_code arg parse failed" << std::endl;

return NULL;
}

if (devId > g_sycl_l0_device_list.size()) {
std::cerr << "Device is not found " << std::endl;
return NULL;
}

BuildFlags build_flags(build_flags_ptr);

try {
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId];
const sycl::device sycl_device = sycl_l0_device_pair.first;
const auto l0_device = sycl_l0_device_pair.second;

const std::string kernel_name = name;
const size_t binary_size = PyBytes_Size(py_bytes);

uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
const auto &ctx = get_default_context(sycl_device);
const auto l0_context =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);

ze_device_compute_properties_t compute_properties = {};
compute_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_COMPUTE_PROPERTIES;
zeDeviceGetComputeProperties(l0_device, &compute_properties);
int32_t n_max_threads = compute_properties.maxTotalGroupSize;
auto [l0_module, l0_kernel, n_spills] =
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
l0_context, build_flags(), true);
const bool debugEnabled = getBoolEnv("TRITON_DEBUG");

constexpr int32_t max_reg_spill = 1000;
const bool is_GRF_mode_specified = build_flags.hasGRFSizeFlag();

// If the register mode isn't set, and the number of spills is greater
// than the threshold, recompile the kernel using large GRF mode.
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
if (debugEnabled)
std::cout << "(I): Detected " << n_spills
<< " spills, recompiling the kernel using large GRF mode"
<< std::endl;

build_flags.addLargeGRFSizeFlag();

try {
auto [l0_module_dgrf, l0_kernel_dgrf, n_spills_dgrf] =
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name,
l0_device, l0_context, build_flags(), true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
l0_device, l0_context, build_flags(), true);
l0_device, l0_context, build_flags(), true /*is_spv*/);


if (debugEnabled)
std::cout << "(I): Kernel has now " << n_spills_dgrf << " spills"
<< std::endl;

std::swap(l0_module, l0_module_dgrf);
std::swap(l0_kernel, l0_kernel_dgrf);
std::swap(n_spills, n_spills_dgrf);

// clean up the unused module and kernel.
auto error_no = zeKernelDestroy(l0_kernel_dgrf);
if (error_no != ZE_RESULT_SUCCESS) {
std::cerr
<< "[Ignoring] Intel - Error during destroy unused L0 kernel"
<< std::endl;
}
error_no = zeModuleDestroy(l0_module_dgrf);
if (error_no != ZE_RESULT_SUCCESS) {
std::cerr
<< "[Ignoring] Intel - Error during destroy unused L0 module"
<< std::endl;
}
} catch (const std::exception &e) {
std::cerr << "[Ignoring] Error during Intel loadBinary with large "
"registers: "
<< e.what() << std::endl;
// construct previous working version
build_flags = BuildFlags(build_flags_ptr);
}
}

if (debugEnabled && n_spills) {
std::cout << "(I): Detected " << n_spills << " spills for \""
<< kernel_name << "\"" << std::endl;
}

size_t szBinary = 0;
zeModuleGetNativeBinary(l0_module, &szBinary, nullptr);
std::vector<uint8_t> pBinary(szBinary);
zeModuleGetNativeBinary(l0_module, &szBinary, pBinary.data());

PyObject *pyBytes = PyBytes_FromStringAndSize(
reinterpret_cast<const char *>(pBinary.data()),
static_cast<Py_ssize_t>(szBinary));
if (!pyBytes)
return NULL;
return Py_BuildValue("(Niii)", pyBytes, build_flags.n_regs(), n_spills,
n_max_threads);

} catch (const std::exception &e) {
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, e.what());
std::cerr << "Error during Intel get_native_code: " << e.what()
<< std::endl;
PyGILState_Release(gil_state);
return NULL;
}
}

extern "C" EXPORT_FUNC PyObject *init_devices(PyObject *cap) {
void *queue = NULL;
if (!(queue = PyLong_AsVoidPtr(cap)))
Expand Down
7 changes: 6 additions & 1 deletion third_party/intel/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class SpirvUtils:

def __init__(self, cache_path: str):
self.shared_library = ctypes.PyDLL(cache_path)
methods = ("init_devices", "load_binary", "wait_on_sycl_queue", "has_opencl_extension")
methods = ("init_devices", "load_binary", "get_native_code", "wait_on_sycl_queue", "has_opencl_extension")
for method in methods:
getattr(self.shared_library, method).restype = ctypes.py_object
getattr(self.shared_library, method).argtypes = (ctypes.py_object, )
Expand All @@ -214,6 +214,9 @@ def load_binary(self, *args):
# driver.active.utils.load_binary((self.name, self.kernel, self.metadata.shared, self.metadata.build_flags, device))
return self.shared_library.load_binary(args)

def get_native_code(self, *args):
return self.shared_library.get_native_code(args)

if os.name != 'nt':

def __del__(self):
Expand Down Expand Up @@ -314,8 +317,10 @@ def __init__(self):
# and can cause `Fatal Python error: Segmentation fault`
mod = compile_module_from_src(src=Path(os.path.join(dirname, "driver.c")).read_text(), name="spirv_utils")
self.load_binary = mod.load_binary
self.get_native_code = mod.get_native_code
self.get_device_properties = mod.get_device_properties
self.device_count = mod.init_devices(self.get_sycl_queue())
# breakpoint()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# breakpoint()

self.wait_on_sycl_queue = mod.wait_on_sycl_queue
self.has_opencl_extension = mod.has_opencl_extension

Expand Down
Loading