third_party/intel/backend/compiler.py (2 changes: 1 addition & 1 deletion)
@@ -338,7 +338,7 @@ def make_spv(src, metadata, options):
if os.path.exists(flog.name):
with open(flog.name) as log_file:
log = log_file.read().strip()
if 'spilled' in log:
if 'spilled' in log and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
"""
The exact message is something like:
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
third_party/intel/backend/driver.c (258 changes: 147 additions & 111 deletions)
@@ -26,25 +26,13 @@ static std::vector<ze_device_handle_t> g_devices;
static std::vector<std::pair<sycl::device, ze_device_handle_t>>
g_sycl_l0_device_list;

static inline void gpuAssert(ze_result_t code) {
if (code != ZE_RESULT_SUCCESS) {
auto str = parseZeResultCode(code);
char err[1024] = {0};
strncat(err, str.c_str(), std::min(str.size(), size_t(1024)));
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
}
}

template <typename T>
static inline T checkSyclErrors(const std::tuple<T, ze_result_t> tuple) {
gpuAssert(std::get<1>(tuple));
if (PyErr_Occurred())
return nullptr;
else
return std::get<0>(tuple);
const auto code = std::get<1>(tuple);
if (code != ZE_RESULT_SUCCESS) {
throw std::runtime_error(parseZeResultCode(code));
}
return std::get<0>(tuple);
}
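
The refactor above replaces the gpuAssert / PyErr_Occurred pattern with C++ exceptions that the catch block at the end of the new loadBinary converts back into a Python RuntimeError. A minimal, self-contained sketch of that pattern, with hypothetical names (unwrap, entryPoint) standing in for the PR's helpers:

#include <Python.h>
#include <stdexcept>
#include <string>
#include <tuple>

template <typename T>
static T unwrap(const std::tuple<T, int> &result) {
  // A non-zero status code becomes a C++ exception instead of Python error
  // state, so intermediate helpers need no PyErr_Occurred() checks.
  if (std::get<1>(result) != 0)
    throw std::runtime_error("native call failed with code " +
                             std::to_string(std::get<1>(result)));
  return std::get<0>(result);
}

static PyObject *entryPoint(PyObject *self, PyObject *args) {
  try {
    // Stand-in for a chain of Level Zero / SYCL calls.
    int value = unwrap(std::make_tuple(42, 0));
    return PyLong_FromLong(value);
  } catch (const std::exception &e) {
    // The single conversion point from C++ exceptions to a Python error.
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return NULL;
  }
}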

static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
@@ -113,14 +101,79 @@ void freeKernelBundle(PyObject *p) {
PyCapsule_GetPointer(p, "kernel_bundle"));
}

using Spills = int32_t;

template <typename L0_DEVICE, typename L0_CONTEXT>
std::tuple<ze_module_handle_t, ze_kernel_handle_t, Spills>
compileLevelZeroObjects(uint8_t *binary_ptr, const size_t binary_size,
const std::string &kernel_name, L0_DEVICE l0_device,
L0_CONTEXT l0_context, const std::string &build_flags,
const bool is_spv) {
auto l0_module =
checkSyclErrors(create_module(l0_context, l0_device, binary_ptr,
binary_size, build_flags.data(), is_spv));

// Retrieve the kernel properties (e.g. register spills).
auto l0_kernel = checkSyclErrors(create_function(l0_module, kernel_name));

ze_kernel_properties_t props;
props.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
props.pNext = nullptr;
checkSyclErrors(
std::make_tuple(NULL, zeKernelGetProperties(l0_kernel, &props)));

const int32_t n_spills = props.spillMemSize;

return std::make_tuple(l0_module, l0_kernel, n_spills);
}
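
The new compileLevelZeroObjects helper folds module creation, kernel creation, and the spill query into a single call. The spill count itself comes from zeKernelGetProperties; a standalone sketch of just that query, with failure reduced to a sentinel value rather than the error handling used in the PR:

#include <level_zero/ze_api.h>
#include <cstdint>

// Returns the spill memory size reported for a compiled kernel, or -1 on
// failure. Assumes "kernel" is a valid ze_kernel_handle_t.
static int32_t querySpillSize(ze_kernel_handle_t kernel) {
  ze_kernel_properties_t props{};
  props.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
  props.pNext = nullptr;
  if (zeKernelGetProperties(kernel, &props) != ZE_RESULT_SUCCESS)
    return -1;
  return static_cast<int32_t>(props.spillMemSize);
}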

struct BuildFlags {
std::string build_flags_str;

const char *LARGE_GRF_FLAG{"-cl-intel-256-GRF-per-thread"};
const char *SMALL_GRF_FLAG{"-cl-intel-128-GRF-per-thread"};
const char *AUTO_GRF_FLAG{"-cl-intel-enable-auto-large-GRF-mode"};

BuildFlags(const char *build_flags) : build_flags_str(build_flags) {}

const std::string &operator()() const { return build_flags_str; }

int32_t n_regs() const {
if (build_flags_str.find(LARGE_GRF_FLAG) != std::string::npos) {
return 256;
}
if (build_flags_str.find(SMALL_GRF_FLAG) != std::string::npos) {
return 128;
}
// TODO: arguably we could return 128 if we find no flag instead of 0. For
// now, stick with the conservative choice and alert the user only if a
// specific GRF mode is specified.
return 0;
}

const bool hasGRFSizeFlag() const {
if (build_flags_str.find(LARGE_GRF_FLAG) != std::string::npos ||
build_flags_str.find(SMALL_GRF_FLAG) != std::string::npos ||
build_flags_str.find(AUTO_GRF_FLAG) != std::string::npos) {
return true;
}

return false;
}

void addLargeGRFSizeFlag() {
build_flags_str = build_flags_str.append(" ").append(LARGE_GRF_FLAG);
}
};
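
A usage sketch for the BuildFlags helper above, assuming the struct definition is in scope; the initial flag string is an arbitrary example:

#include <cassert>

int main() {
  BuildFlags flags("-cl-opt-disable"); // example input, not from the PR
  assert(!flags.hasGRFSizeFlag());     // no GRF mode requested yet
  assert(flags.n_regs() == 0);         // unknown GRF mode is reported as 0

  flags.addLargeGRFSizeFlag();
  assert(flags.hasGRFSizeFlag());
  assert(flags.n_regs() == 256);
  // flags() now ends with " -cl-intel-256-GRF-per-thread".
  return 0;
}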

static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name, *build_flags;
const char *name, *build_flags_ptr;
int shared;
PyObject *py_bytes;
int devId;

if (!PyArg_ParseTuple(args, "sSisi", &name, &py_bytes, &shared, &build_flags,
&devId)) {
if (!PyArg_ParseTuple(args, "sSisi", &name, &py_bytes, &shared,
&build_flags_ptr, &devId)) {
std::cerr << "loadBinary arg parse failed" << std::endl;
return NULL;
}
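
A short reference for the "sSisi" format string in the PyArg_ParseTuple call above; the mapping follows the CPython argument-parsing API, and the meanings are taken from the variable names:

/*  s -> const char *name              kernel name (UTF-8 string)
    S -> PyObject *py_bytes            bytes object holding the kernel binary
    i -> int shared                    integer (shared memory size, per its name)
    s -> const char *build_flags_ptr   build flags string
    i -> int devId                     device index                            */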
@@ -130,106 +183,89 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
return NULL;
}

const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId];
const sycl::device sycl_device = sycl_l0_device_pair.first;
BuildFlags build_flags(build_flags_ptr);

std::string kernel_name = name;
const size_t binary_size = PyBytes_Size(py_bytes);
try {

uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
const auto ctx = sycl_device.get_platform().ext_oneapi_get_default_context();
const auto l0_device =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
const auto l0_context =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId];
const sycl::device sycl_device = sycl_l0_device_pair.first;

const auto use_native_code =
isEnvValueBool(getStrEnv("TRITON_XPU_GEN_NATIVE_CODE"));
const bool is_spv = use_native_code ? !(*use_native_code) : true;
const std::string kernel_name = name;
const size_t binary_size = PyBytes_Size(py_bytes);

auto l0_module = checkSyclErrors(create_module(
l0_context, l0_device, binary_ptr, binary_size, build_flags, is_spv));
uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
const auto ctx =
sycl_device.get_platform().ext_oneapi_get_default_context();
const auto l0_device =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
const auto l0_context =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);

auto checkL0Errors = [&](auto l0_module) -> ze_kernel_handle_t {
if (PyErr_Occurred()) {
// check for errors from module creation
return NULL;
}
ze_kernel_handle_t l0_kernel =
checkSyclErrors(create_function(l0_module, kernel_name));
if (PyErr_Occurred()) {
// check for errors from kernel creation
return NULL;
}
return l0_kernel;
};
const auto use_native_code =
isEnvValueBool(getStrEnv("TRITON_XPU_GEN_NATIVE_CODE"));
const bool is_spv = use_native_code ? !(*use_native_code) : true;

// Retrieve the kernel properties (e.g. register spills).
ze_kernel_handle_t l0_kernel = checkL0Errors(l0_module);
ze_kernel_properties_t props;
props.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
props.pNext = nullptr;
gpuAssert(zeKernelGetProperties(l0_kernel, &props));

int32_t n_spills = props.spillMemSize;
const int32_t n_regs = 0;

if (is_spv) {
constexpr int32_t max_reg_spill = 1000;
std::string build_flags_str(build_flags);
bool is_GRF_mode_specified = false;

// Check whether the GRF mode is specified by the build flags.
if (build_flags_str.find("-cl-intel-256-GRF-per-thread") !=
std::string::npos ||
build_flags_str.find("-cl-intel-128-GRF-per-thread") !=
std::string::npos ||
build_flags_str.find("-cl-intel-enable-auto-large-GRF-mode") !=
std::string::npos) {
is_GRF_mode_specified = true;
}
auto [l0_module, l0_kernel, n_spills] =
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
l0_context, build_flags(), is_spv);

if (is_spv) {
constexpr int32_t max_reg_spill = 1000;
const bool is_GRF_mode_specified = build_flags.hasGRFSizeFlag();

// If the register mode isn't set, and the number of spills is greater
// than the threshold, recompile the kernel using large GRF mode.
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
const std::optional<bool> debugEnabled =
isEnvValueBool(getStrEnv("TRITON_DEBUG"));
if (debugEnabled)
std::cout << "(I): Detected " << n_spills
<< " spills, recompiling the kernel using large GRF mode"
<< std::endl;

build_flags.addLargeGRFSizeFlag();

// If the register mode isn't set, and the number of spills is greater
// than the threshold, recompile the kernel using large GRF mode.
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
const std::optional<bool> debugEnabled =
isEnvValueBool(getStrEnv("TRITON_DEBUG"));
if (debugEnabled)
std::cout << "(I): Detected " << n_spills
<< " spills, recompiling kernel \"" << kernel_name
<< "\" using large GRF mode" << std::endl;

const std::string new_build_flags =
build_flags_str.append(" -cl-intel-256-GRF-per-thread");
l0_module = checkSyclErrors(
create_module(l0_context, l0_device, binary_ptr, binary_size,
new_build_flags.c_str(), is_spv));

l0_kernel = checkL0Errors(l0_module);
gpuAssert(zeKernelGetProperties(l0_kernel, &props));
n_spills = props.spillMemSize;

if (debugEnabled)
std::cout << "(I): Kernel has now " << n_spills << " spills"
<< std::endl;
auto [l0_module, l0_kernel, n_spills] = compileLevelZeroObjects(
binary_ptr, binary_size, kernel_name, l0_device, l0_context,
build_flags(), is_spv);

if (debugEnabled)
std::cout << "(I): Kernel has now " << n_spills << " spills"
<< std::endl;
}
}
}

auto mod = new sycl::kernel_bundle<sycl::bundle_state::executable>(
sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
sycl::bundle_state::executable>(
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
ctx));
sycl::kernel *fun =
new sycl::kernel(sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
{*mod, l0_kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
ctx));
auto kernel_py =
PyCapsule_New(reinterpret_cast<void *>(fun), "kernel", freeKernel);
auto kernel_bundle_py = PyCapsule_New(reinterpret_cast<void *>(mod),
"kernel_bundle", freeKernelBundle);

return Py_BuildValue("(OOii)", kernel_bundle_py, kernel_py, n_regs, n_spills);
auto n_regs = build_flags.n_regs();

auto mod = new sycl::kernel_bundle<sycl::bundle_state::executable>(
sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
sycl::bundle_state::executable>(
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
ctx));
sycl::kernel *fun = new sycl::kernel(
sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
{*mod, l0_kernel,
sycl::ext::oneapi::level_zero::ownership::transfer},
ctx));
auto kernel_py =
PyCapsule_New(reinterpret_cast<void *>(fun), "kernel", freeKernel);
auto kernel_bundle_py = PyCapsule_New(reinterpret_cast<void *>(mod),
"kernel_bundle", freeKernelBundle);

return Py_BuildValue("(OOii)", kernel_bundle_py, kernel_py, n_regs,
n_spills);

} catch (const std::exception &e) {
char err[1024] = {0};
std::string_view error_str(e.what());
strncat(err, error_str.data(), std::min(error_str.size(), size_t(1024)));
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
std::cerr << "Error during Intel loadBinary: " << err << std::endl;
PyGILState_Release(gil_state);
return NULL;
}
}
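
Condensed, the policy loadBinary now implements is: compile once, and if the binary is SPIR-V, no GRF mode was requested, and the spill count exceeds the threshold, append the large-GRF flag and recompile. A sketch reusing the names defined above (std::tie is used here so the recompiled handles replace the originals):

// Assumes the variables from loadBinary (binary_ptr, build_flags, ...) are in scope.
constexpr int32_t max_reg_spill = 1000;
auto [l0_module, l0_kernel, n_spills] = compileLevelZeroObjects(
    binary_ptr, binary_size, kernel_name, l0_device, l0_context,
    build_flags(), is_spv);

if (is_spv && !build_flags.hasGRFSizeFlag() && n_spills > max_reg_spill) {
  build_flags.addLargeGRFSizeFlag();
  std::tie(l0_module, l0_kernel, n_spills) = compileLevelZeroObjects(
      binary_ptr, binary_size, kernel_name, l0_device, l0_context,
      build_flags(), is_spv);
}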

static PyObject *initContext(PyObject *self, PyObject *args) {