Skip to content

Commit 21933fb

Browse files
authored
[NFI]: Update driver.py (#4819)
Making this file's implementation closer to the upstream one. Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 4d11eb7 commit 21933fb

File tree

2 files changed

+56
-58
lines changed

2 files changed

+56
-58
lines changed

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(self, target: tuple) -> None:
114114
if not isinstance(target.arch, dict):
115115
raise TypeError("target.arch is not a dict")
116116
dirname = os.path.dirname(os.path.realpath(__file__))
117-
mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
117+
mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils")
118118
self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0))
119119
self.properties = self.parse_target(target.arch)
120120
self.binary_ext = "spv"

third_party/intel/backend/driver.py

Lines changed: 55 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from functools import cached_property
1010

1111
from triton import knobs
12-
from triton.runtime.build import _build, platform_key
12+
from triton.runtime.build import _build, platform_key, _load_module_from_path
1313
from triton.runtime.cache import get_cache_manager
1414
from triton.backends.compiler import GPUTarget
1515
from triton.backends.driver import DriverBase
@@ -252,7 +252,7 @@ def __del__(self):
252252
ctypes.windll.kernel32.FreeLibrary(handle)
253253

254254

255-
def compile_module_from_src(src, name):
255+
def compile_module_from_src(src: str, name: str):
256256
hasher = hashlib.sha256(__CACHE_VERSION.encode("utf-8"))
257257
hasher.update((src + platform_key()).encode("utf-8"))
258258
key = hasher.hexdigest()
@@ -278,18 +278,14 @@ def compile_module_from_src(src, name):
278278

279279
if name == 'arch_utils':
280280
return ArchParser(cache_path)
281-
elif name == 'spirv_utils':
281+
if name == 'spirv_utils':
282282
return SpirvUtils(cache_path)
283-
elif name == '__triton_launcher':
283+
if name == '__triton_launcher':
284284
return TritonLauncher(cache_path)
285-
elif name == 'proton_utils':
285+
if name == 'proton_utils':
286286
return cache_path
287287

288-
import importlib.util
289-
spec = importlib.util.spec_from_file_location(name, cache_path)
290-
mod = importlib.util.module_from_spec(spec)
291-
spec.loader.exec_module(mod)
292-
return mod
288+
return _load_module_from_path(name, cache_path)
293289

294290

295291
# ------------------------
@@ -308,12 +304,12 @@ def __init__(self):
308304
dirname = os.path.dirname(os.path.realpath(__file__))
309305
# we save `spirv_utils` module so that the destructor is not called prematurely, which will unload the dll
310306
# and can cause `Fatal Python error: Segmentation fault`
311-
self.mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "spirv_utils")
312-
self.load_binary = self.mod.load_binary
313-
self.get_device_properties = self.mod.get_device_properties
314-
self.device_count = self.mod.init_devices(self.get_sycl_queue())
315-
self.wait_on_sycl_queue = self.mod.wait_on_sycl_queue
316-
self.has_opencl_extension = self.mod.has_opencl_extension
307+
mod = compile_module_from_src(src=Path(os.path.join(dirname, "driver.c")).read_text(), name="spirv_utils")
308+
self.load_binary = mod.load_binary
309+
self.get_device_properties = mod.get_device_properties
310+
self.device_count = mod.init_devices(self.get_sycl_queue())
311+
self.wait_on_sycl_queue = mod.wait_on_sycl_queue
312+
self.has_opencl_extension = mod.has_opencl_extension
317313

318314
def get_current_device(self):
319315
import torch
@@ -369,6 +365,8 @@ def ty_to_cpp(ty):
369365
"fp64": "pack_fp64",
370366
}
371367

368+
_BASE_ARGS_FORMAT = "iiiOOOOOO"
369+
372370

373371
def make_launcher(constants, signature):
374372

@@ -411,12 +409,11 @@ def format_of(ty):
411409
}[ty_to_cpp(ty)]
412410

413411
args_format = ''.join([format_of(ty) for ty in signature.values()])
414-
format = "iiiOOOOOO" + args_format
412+
format = _BASE_ARGS_FORMAT + args_format
415413
signature = ','.join(map(_serialize_signature, signature.values()))
416414
signature = list(filter(bool, signature.split(',')))
417415
signature = {i: s for i, s in enumerate(signature)}
418416
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
419-
420417
# Record the end of regular arguments;
421418
# subsequent arguments are architecture-specific descriptors.
422419
arg_decl_list = []
@@ -706,15 +703,6 @@ def format_of(ty):
706703
return src
707704

708705

709-
def serialize_kernel_metadata(arg, args_dict):
710-
args_dict['num_warps'] = arg.num_warps
711-
args_dict['threads_per_warp'] = arg.threads_per_warp
712-
args_dict['shared_memory'] = arg.shared
713-
args_dict['kernel_name'] = arg.name
714-
args_dict['spv_name'] = f"{arg.name}.spv"
715-
args_dict['build_flags'] = arg.build_flags
716-
717-
718706
def serialize_args(args, constants, signature):
719707
import torch
720708
import numbers
@@ -723,6 +711,14 @@ def serialize_args(args, constants, signature):
723711
os.makedirs(dir_path)
724712
print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}")
725713

714+
def serialize_kernel_metadata(arg, args_dict):
715+
args_dict['num_warps'] = arg.num_warps
716+
args_dict['threads_per_warp'] = arg.threads_per_warp
717+
args_dict['shared_memory'] = arg.shared
718+
args_dict['kernel_name'] = arg.name
719+
args_dict['spv_name'] = f"{arg.name}.spv"
720+
args_dict['build_flags'] = arg.build_flags
721+
726722
cnt = 0
727723
args_dict = {"gridX": int(args[cnt]), "gridY": int(args[cnt + 1]), "gridZ": int(args[cnt + 2])}
728724
# 3: stream
@@ -774,7 +770,7 @@ def __init__(self, src, metadata):
774770
self.constants = {arg_idx(idx): value for idx, value in constants.items()}
775771
self.signature = {idx: value for idx, value in src.signature.items()}
776772
src = make_launcher(self.constants, self.signature)
777-
self.mod = compile_module_from_src(src, "__triton_launcher")
773+
self.mod = compile_module_from_src(src=src, name="__triton_launcher")
778774
# Serialize KernelArguments for SPIR-V Runner
779775
self.serialize_kernel_args = knobs.intel.dump_spirv_kernel_args
780776

@@ -788,6 +784,7 @@ class XPUDriver(DriverBase):
788784

789785
def __init__(self):
790786
self.launcher_cls = XPULauncher
787+
super().__init__()
791788

792789
def __getattr__(self, name):
793790
# Lazily initialize utils to avoid unnecessary XPU runtime invocations.
@@ -805,45 +802,46 @@ def get_current_stream(self, device):
805802
import torch
806803
return torch.xpu.current_stream().sycl_queue
807804

808-
def update_advanced_features(self, device, dev_property):
809-
if knobs.intel.device_extensions:
810-
# May be useful when using the `TRITON INTEL_DEVICE_ARCH` environment variable
811-
# to be able to flexibly turn on/off the advanced feature.
812-
supported_extensions = set()
813-
supported_extensions.update(knobs.intel.device_extensions.split(" "))
814-
dev_property[
815-
"has_subgroup_matrix_multiply_accumulate"] = "cl_intel_subgroup_matrix_multiply_accumulate" in supported_extensions
816-
dev_property[
817-
"has_subgroup_matrix_multiply_accumulate_tensor_float32"] = "cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32" in supported_extensions
818-
dev_property["has_subgroup_2d_block_io"] = "cl_intel_subgroup_2d_block_io" in supported_extensions
819-
dev_property["has_bfloat16_conversions"] = "cl_intel_bfloat16_conversions" in supported_extensions
820-
else:
821-
check = self.utils.has_opencl_extension
822-
# FIXME: eventually even LTS driver will support OpenCL extensions.
823-
# Please remove this after upgrading to a new version.
824-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/4708
825-
is_lts = "1.3" in dev_property["driver_version"]
826-
dev_property["has_subgroup_matrix_multiply_accumulate"] = check(
827-
device, b"cl_intel_subgroup_matrix_multiply_accumulate") if not is_lts else False
828-
dev_property["has_subgroup_matrix_multiply_accumulate_tensor_float32"] = check(
829-
device, b"cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32") if not is_lts else False
830-
dev_property["has_subgroup_2d_block_io"] = check(device,
831-
b"cl_intel_subgroup_2d_block_io") if not is_lts else False
832-
dev_property["has_bfloat16_conversions"] = check(device,
833-
b"cl_intel_bfloat16_conversions") if not is_lts else False
834-
835805
def get_current_target(self):
836806
import torch
837807
device = self.get_current_device()
838808
dev_property = torch.xpu.get_device_capability(device)
839-
self.update_advanced_features(device, dev_property)
809+
810+
def update_advanced_features(device, dev_property):
811+
if knobs.intel.device_extensions:
812+
# May be useful when using the `TRITON INTEL_DEVICE_ARCH` environment variable
813+
# to be able to flexibly turn on/off the advanced feature.
814+
supported_extensions = set()
815+
supported_extensions.update(knobs.intel.device_extensions.split(" "))
816+
dev_property[
817+
"has_subgroup_matrix_multiply_accumulate"] = "cl_intel_subgroup_matrix_multiply_accumulate" in supported_extensions
818+
dev_property[
819+
"has_subgroup_matrix_multiply_accumulate_tensor_float32"] = "cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32" in supported_extensions
820+
dev_property["has_subgroup_2d_block_io"] = "cl_intel_subgroup_2d_block_io" in supported_extensions
821+
dev_property["has_bfloat16_conversions"] = "cl_intel_bfloat16_conversions" in supported_extensions
822+
else:
823+
check = self.utils.has_opencl_extension
824+
# FIXME: eventually even LTS driver will support OpenCL extensions.
825+
# Please remove this after upgrading to a new version.
826+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/4708
827+
is_lts = "1.3" in dev_property["driver_version"]
828+
dev_property["has_subgroup_matrix_multiply_accumulate"] = check(
829+
device, b"cl_intel_subgroup_matrix_multiply_accumulate") if not is_lts else False
830+
dev_property["has_subgroup_matrix_multiply_accumulate_tensor_float32"] = check(
831+
device, b"cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32") if not is_lts else False
832+
dev_property["has_subgroup_2d_block_io"] = check(
833+
device, b"cl_intel_subgroup_2d_block_io") if not is_lts else False
834+
dev_property["has_bfloat16_conversions"] = check(
835+
device, b"cl_intel_bfloat16_conversions") if not is_lts else False
836+
837+
update_advanced_features(device, dev_property)
840838
return GPUTarget("xpu", dev_property, warp_size=32)
841839

842840
def build_proton_help_lib(self):
843841
from triton.backends.intel.driver import compile_module_from_src
844842

845843
dirname = os.path.dirname(os.path.realpath(__file__))
846-
return compile_module_from_src(Path(dirname).joinpath("proton_utils.cpp").read_text(), "proton_utils")
844+
return compile_module_from_src(src=Path(dirname).joinpath("proton_utils.cpp").read_text(), name="proton_utils")
847845

848846
def get_active_torch_device(self):
849847
import torch

0 commit comments

Comments
 (0)