Skip to content

Commit d7d55b8

Browse files
authored
Implement arch_parser.c as pure dll (#3230)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent b287705 commit d7d55b8

File tree

3 files changed

+52
-35
lines changed

3 files changed

+52
-35
lines changed

third_party/intel/backend/arch_parser.c

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,17 @@
1010

1111
#include <sycl/sycl.hpp>
1212

13-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
14-
#include <Python.h>
15-
#include <numpy/arrayobject.h>
16-
17-
static PyObject *parseDeviceArch(PyObject *self, PyObject *args) {
18-
uint64_t dev_arch;
19-
assert(PyArg_ParseTuple(args, "K", &dev_arch) && "Expected an integer");
13+
#if defined(_WIN32)
14+
#define EXPORT_FUNC __declspec(dllexport)
15+
#else
16+
#define EXPORT_FUNC __attribute__((visibility("default")))
17+
#endif
2018

19+
extern "C" EXPORT_FUNC const char *parse_device_arch(uint64_t dev_arch) {
2120
sycl::ext::oneapi::experimental::architecture sycl_arch =
2221
static_cast<sycl::ext::oneapi::experimental::architecture>(dev_arch);
2322
// FIXME: Add support for more architectures.
24-
std::string arch = "";
23+
const char *arch = "";
2524
switch (sycl_arch) {
2625
case sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc:
2726
arch = "pvc";
@@ -39,24 +38,5 @@ static PyObject *parseDeviceArch(PyObject *self, PyObject *args) {
3938
std::cerr << "sycl_arch not recognized: " << (int)sycl_arch << std::endl;
4039
}
4140

42-
return Py_BuildValue("s", arch.c_str());
43-
}
44-
45-
static PyMethodDef ModuleMethods[] = {
46-
{"parse_device_arch", parseDeviceArch, METH_VARARGS,
47-
"parse device architecture"},
48-
{NULL, NULL, 0, NULL} // sentinel
49-
};
50-
51-
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "arch_utils",
52-
NULL, // documentation
53-
-1, // size
54-
ModuleMethods};
55-
56-
PyMODINIT_FUNC PyInit_arch_utils(void) {
57-
if (PyObject *m = PyModule_Create(&ModuleDef)) {
58-
PyModule_AddFunctions(m, ModuleMethods);
59-
return m;
60-
}
61-
return NULL;
41+
return arch;
6242
}

third_party/intel/backend/compiler.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(self, target: tuple) -> None:
132132
raise TypeError("target.arch is not a dict")
133133
dirname = os.path.dirname(os.path.realpath(__file__))
134134
mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
135-
self.parse_device_arch = mod.parse_device_arch
135+
self.device_arch = mod.parse_device_arch(target.arch.get('architecture', 0))
136136
self.properties = self.parse_target(target.arch)
137137
self.binary_ext = "spv"
138138

@@ -155,13 +155,12 @@ def parse_target(self, tgt_prop) -> dict:
155155
dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
156156
dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)
157157

158-
device_arch = self.parse_device_arch(tgt_prop.get('architecture', 0))
159-
if device_arch and shutil.which('ocloc'):
160-
if device_arch in self.device_props:
161-
dev_prop.update(self.device_props[device_arch])
158+
if self.device_arch and shutil.which('ocloc'):
159+
if self.device_arch in self.device_props:
160+
dev_prop.update(self.device_props[self.device_arch])
162161
return dev_prop
163162
try:
164-
ocloc_cmd = ['ocloc', 'query', 'CL_DEVICE_EXTENSIONS', '-device', device_arch]
163+
ocloc_cmd = ['ocloc', 'query', 'CL_DEVICE_EXTENSIONS', '-device', self.device_arch]
165164
with tempfile.TemporaryDirectory() as temp_dir:
166165
output = subprocess.check_output(ocloc_cmd, text=True, cwd=temp_dir)
167166
supported_extensions = set()
@@ -174,7 +173,7 @@ def parse_target(self, tgt_prop) -> dict:
174173
'has_subgroup_matrix_multiply_accumulate_tensor_float32'] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
175174
ocloc_dev_prop['has_subgroup_2d_block_io'] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
176175
ocloc_dev_prop['has_bfloat16_conversions'] = 'cl_intel_bfloat16_conversions' in supported_extensions
177-
self.device_props[device_arch] = ocloc_dev_prop
176+
self.device_props[self.device_arch] = ocloc_dev_prop
178177
dev_prop.update(ocloc_dev_prop)
179178
except subprocess.CalledProcessError:
180179
# Note: LTS driver does not support ocloc query CL_DEVICE_EXTENSIONS.

third_party/intel/backend/driver.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import hashlib
44
import shutil
5+
import ctypes
56
import sysconfig
67
import tempfile
78
from pathlib import Path
@@ -134,6 +135,39 @@ def libsycl_dir(self) -> str:
134135
COMPILATION_HELPER = CompilationHelper()
135136

136137

138+
class ArchParser:
139+
140+
def __init__(self, cache_path: str):
141+
self.shared_library = ctypes.CDLL(cache_path)
142+
self.shared_library.parse_device_arch.restype = ctypes.c_char_p
143+
self.shared_library.parse_device_arch.argtypes = (ctypes.c_uint64, )
144+
145+
def __getattribute__(self, name):
146+
if name == "parse_device_arch":
147+
shared_library = super().__getattribute__("shared_library")
148+
attr = getattr(shared_library, name)
149+
150+
def wrapper(*args, **kwargs):
151+
return attr(*args, **kwargs).decode("utf-8")
152+
153+
return wrapper
154+
155+
return super().__getattribute__(name)
156+
157+
if os.name != 'nt':
158+
159+
def __del__(self):
160+
handle = self.shared_library._handle
161+
self.shared_library.dlclose.argtypes = (ctypes.c_void_p, )
162+
self.shared_library.dlclose(handle)
163+
else:
164+
165+
def __del__(self):
166+
handle = self.shared_library._handle
167+
ctypes.windll.kernel32.FreeLibrary.argtypes = (ctypes.c_uint64, )
168+
ctypes.windll.kernel32.FreeLibrary(handle)
169+
170+
137171
def compile_module_from_src(src, name):
138172
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
139173
cache = get_cache_manager(key)
@@ -155,6 +189,10 @@ def compile_module_from_src(src, name):
155189
COMPILATION_HELPER.libraries, extra_compile_args=extra_compiler_args)
156190
with open(so, "rb") as f:
157191
cache_path = cache.put(f.read(), file_name, binary=True)
192+
193+
if name == 'arch_utils':
194+
return ArchParser(cache_path)
195+
158196
import importlib.util
159197
spec = importlib.util.spec_from_file_location(name, cache_path)
160198
mod = importlib.util.module_from_spec(spec)

0 commit comments

Comments
 (0)