9
9
from functools import cached_property
10
10
11
11
from triton import knobs
12
- from triton .runtime .build import _build , platform_key
12
+ from triton .runtime .build import _build , platform_key , _load_module_from_path
13
13
from triton .runtime .cache import get_cache_manager
14
14
from triton .backends .compiler import GPUTarget
15
15
from triton .backends .driver import DriverBase
@@ -252,7 +252,7 @@ def __del__(self):
252
252
ctypes .windll .kernel32 .FreeLibrary (handle )
253
253
254
254
255
- def compile_module_from_src (src , name ):
255
+ def compile_module_from_src (src : str , name : str ):
256
256
hasher = hashlib .sha256 (__CACHE_VERSION .encode ("utf-8" ))
257
257
hasher .update ((src + platform_key ()).encode ("utf-8" ))
258
258
key = hasher .hexdigest ()
@@ -278,18 +278,14 @@ def compile_module_from_src(src, name):
278
278
279
279
if name == 'arch_utils' :
280
280
return ArchParser (cache_path )
281
- elif name == 'spirv_utils' :
281
+ if name == 'spirv_utils' :
282
282
return SpirvUtils (cache_path )
283
- elif name == '__triton_launcher' :
283
+ if name == '__triton_launcher' :
284
284
return TritonLauncher (cache_path )
285
- elif name == 'proton_utils' :
285
+ if name == 'proton_utils' :
286
286
return cache_path
287
287
288
- import importlib .util
289
- spec = importlib .util .spec_from_file_location (name , cache_path )
290
- mod = importlib .util .module_from_spec (spec )
291
- spec .loader .exec_module (mod )
292
- return mod
288
+ return _load_module_from_path (name , cache_path )
293
289
294
290
295
291
# ------------------------
@@ -308,12 +304,12 @@ def __init__(self):
308
304
dirname = os .path .dirname (os .path .realpath (__file__ ))
309
305
# we save `spirv_utils` module so that the destructor is not called prematurely, which will unload the dll
310
306
# and can cause `Fatal Python error: Segmentation fault`
311
- self . mod = compile_module_from_src (Path (os .path .join (dirname , "driver.c" )).read_text (), "spirv_utils" )
312
- self .load_binary = self . mod .load_binary
313
- self .get_device_properties = self . mod .get_device_properties
314
- self .device_count = self . mod .init_devices (self .get_sycl_queue ())
315
- self .wait_on_sycl_queue = self . mod .wait_on_sycl_queue
316
- self .has_opencl_extension = self . mod .has_opencl_extension
307
+ mod = compile_module_from_src (src = Path (os .path .join (dirname , "driver.c" )).read_text (), name = "spirv_utils" )
308
+ self .load_binary = mod .load_binary
309
+ self .get_device_properties = mod .get_device_properties
310
+ self .device_count = mod .init_devices (self .get_sycl_queue ())
311
+ self .wait_on_sycl_queue = mod .wait_on_sycl_queue
312
+ self .has_opencl_extension = mod .has_opencl_extension
317
313
318
314
def get_current_device (self ):
319
315
import torch
@@ -369,6 +365,8 @@ def ty_to_cpp(ty):
369
365
"fp64" : "pack_fp64" ,
370
366
}
371
367
368
+ _BASE_ARGS_FORMAT = "iiiOOOOOO"
369
+
372
370
373
371
def make_launcher (constants , signature ):
374
372
@@ -411,12 +409,11 @@ def format_of(ty):
411
409
}[ty_to_cpp (ty )]
412
410
413
411
args_format = '' .join ([format_of (ty ) for ty in signature .values ()])
414
- format = "iiiOOOOOO" + args_format
412
+ format = _BASE_ARGS_FORMAT + args_format
415
413
signature = ',' .join (map (_serialize_signature , signature .values ()))
416
414
signature = list (filter (bool , signature .split (',' )))
417
415
signature = {i : s for i , s in enumerate (signature )}
418
416
args_list = ', ' + ', ' .join (f"&_arg{ i } " for i , ty in signature .items ()) if len (signature ) > 0 else ''
419
-
420
417
# Record the end of regular arguments;
421
418
# subsequent arguments are architecture-specific descriptors.
422
419
arg_decl_list = []
@@ -706,15 +703,6 @@ def format_of(ty):
706
703
return src
707
704
708
705
709
- def serialize_kernel_metadata (arg , args_dict ):
710
- args_dict ['num_warps' ] = arg .num_warps
711
- args_dict ['threads_per_warp' ] = arg .threads_per_warp
712
- args_dict ['shared_memory' ] = arg .shared
713
- args_dict ['kernel_name' ] = arg .name
714
- args_dict ['spv_name' ] = f"{ arg .name } .spv"
715
- args_dict ['build_flags' ] = arg .build_flags
716
-
717
-
718
706
def serialize_args (args , constants , signature ):
719
707
import torch
720
708
import numbers
@@ -723,6 +711,14 @@ def serialize_args(args, constants, signature):
723
711
os .makedirs (dir_path )
724
712
print (f"Path to directory consisting of SPIR-V Runner data: { dir_path } " )
725
713
714
+ def serialize_kernel_metadata (arg , args_dict ):
715
+ args_dict ['num_warps' ] = arg .num_warps
716
+ args_dict ['threads_per_warp' ] = arg .threads_per_warp
717
+ args_dict ['shared_memory' ] = arg .shared
718
+ args_dict ['kernel_name' ] = arg .name
719
+ args_dict ['spv_name' ] = f"{ arg .name } .spv"
720
+ args_dict ['build_flags' ] = arg .build_flags
721
+
726
722
cnt = 0
727
723
args_dict = {"gridX" : int (args [cnt ]), "gridY" : int (args [cnt + 1 ]), "gridZ" : int (args [cnt + 2 ])}
728
724
# 3: stream
@@ -774,7 +770,7 @@ def __init__(self, src, metadata):
774
770
self .constants = {arg_idx (idx ): value for idx , value in constants .items ()}
775
771
self .signature = {idx : value for idx , value in src .signature .items ()}
776
772
src = make_launcher (self .constants , self .signature )
777
- self .mod = compile_module_from_src (src , "__triton_launcher" )
773
+ self .mod = compile_module_from_src (src = src , name = "__triton_launcher" )
778
774
# Serialize KernelArguments for SPIR-V Runner
779
775
self .serialize_kernel_args = knobs .intel .dump_spirv_kernel_args
780
776
@@ -788,6 +784,7 @@ class XPUDriver(DriverBase):
788
784
789
785
def __init__ (self ):
790
786
self .launcher_cls = XPULauncher
787
+ super ().__init__ ()
791
788
792
789
def __getattr__ (self , name ):
793
790
# Lazily initialize utils to avoid unnecessary XPU runtime invocations.
@@ -805,45 +802,46 @@ def get_current_stream(self, device):
805
802
import torch
806
803
return torch .xpu .current_stream ().sycl_queue
807
804
808
- def update_advanced_features (self , device , dev_property ):
809
- if knobs .intel .device_extensions :
810
- # May be useful when using the `TRITON_INTEL_DEVICE_ARCH` environment variable
811
- # to be able to flexibly turn on/off the advanced feature.
812
- supported_extensions = set ()
813
- supported_extensions .update (knobs .intel .device_extensions .split (" " ))
814
- dev_property [
815
- "has_subgroup_matrix_multiply_accumulate" ] = "cl_intel_subgroup_matrix_multiply_accumulate" in supported_extensions
816
- dev_property [
817
- "has_subgroup_matrix_multiply_accumulate_tensor_float32" ] = "cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32" in supported_extensions
818
- dev_property ["has_subgroup_2d_block_io" ] = "cl_intel_subgroup_2d_block_io" in supported_extensions
819
- dev_property ["has_bfloat16_conversions" ] = "cl_intel_bfloat16_conversions" in supported_extensions
820
- else :
821
- check = self .utils .has_opencl_extension
822
- # FIXME: eventually even LTS driver will support OpenCL extensions.
823
- # Please remove this after upgrading to a new version.
824
- # https://github.com/intel/intel-xpu-backend-for-triton/issues/4708
825
- is_lts = "1.3" in dev_property ["driver_version" ]
826
- dev_property ["has_subgroup_matrix_multiply_accumulate" ] = check (
827
- device , b"cl_intel_subgroup_matrix_multiply_accumulate" ) if not is_lts else False
828
- dev_property ["has_subgroup_matrix_multiply_accumulate_tensor_float32" ] = check (
829
- device , b"cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32" ) if not is_lts else False
830
- dev_property ["has_subgroup_2d_block_io" ] = check (device ,
831
- b"cl_intel_subgroup_2d_block_io" ) if not is_lts else False
832
- dev_property ["has_bfloat16_conversions" ] = check (device ,
833
- b"cl_intel_bfloat16_conversions" ) if not is_lts else False
834
-
835
805
def get_current_target (self ):
836
806
import torch
837
807
device = self .get_current_device ()
838
808
dev_property = torch .xpu .get_device_capability (device )
839
- self .update_advanced_features (device , dev_property )
809
+
810
+ def update_advanced_features (device , dev_property ):
811
+ if knobs .intel .device_extensions :
812
+ # May be useful when using the `TRITON_INTEL_DEVICE_ARCH` environment variable
813
+ # to be able to flexibly turn on/off the advanced feature.
814
+ supported_extensions = set ()
815
+ supported_extensions .update (knobs .intel .device_extensions .split (" " ))
816
+ dev_property [
817
+ "has_subgroup_matrix_multiply_accumulate" ] = "cl_intel_subgroup_matrix_multiply_accumulate" in supported_extensions
818
+ dev_property [
819
+ "has_subgroup_matrix_multiply_accumulate_tensor_float32" ] = "cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32" in supported_extensions
820
+ dev_property ["has_subgroup_2d_block_io" ] = "cl_intel_subgroup_2d_block_io" in supported_extensions
821
+ dev_property ["has_bfloat16_conversions" ] = "cl_intel_bfloat16_conversions" in supported_extensions
822
+ else :
823
+ check = self .utils .has_opencl_extension
824
+ # FIXME: eventually even LTS driver will support OpenCL extensions.
825
+ # Please remove this after upgrading to a new version.
826
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/4708
827
+ is_lts = "1.3" in dev_property ["driver_version" ]
828
+ dev_property ["has_subgroup_matrix_multiply_accumulate" ] = check (
829
+ device , b"cl_intel_subgroup_matrix_multiply_accumulate" ) if not is_lts else False
830
+ dev_property ["has_subgroup_matrix_multiply_accumulate_tensor_float32" ] = check (
831
+ device , b"cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32" ) if not is_lts else False
832
+ dev_property ["has_subgroup_2d_block_io" ] = check (
833
+ device , b"cl_intel_subgroup_2d_block_io" ) if not is_lts else False
834
+ dev_property ["has_bfloat16_conversions" ] = check (
835
+ device , b"cl_intel_bfloat16_conversions" ) if not is_lts else False
836
+
837
+ update_advanced_features (device , dev_property )
840
838
return GPUTarget ("xpu" , dev_property , warp_size = 32 )
841
839
842
840
def build_proton_help_lib (self ):
843
841
from triton .backends .intel .driver import compile_module_from_src
844
842
845
843
dirname = os .path .dirname (os .path .realpath (__file__ ))
846
- return compile_module_from_src (Path (dirname ).joinpath ("proton_utils.cpp" ).read_text (), "proton_utils" )
844
+ return compile_module_from_src (src = Path (dirname ).joinpath ("proton_utils.cpp" ).read_text (), name = "proton_utils" )
847
845
848
846
def get_active_torch_device (self ):
849
847
import torch
0 commit comments