@@ -399,19 +399,77 @@ def format_of(ty):
399399 return src
400400
401401
402+ def serialize_kernel_metadata (arg , args_dict ):
403+ args_dict ["num_warps" ] = arg .num_warps
404+ args_dict ["threads_per_warp" ] = arg .threads_per_warp
405+ args_dict ["shared_memory" ] = arg .shared
406+ args_dict ["kernel_name" ] = arg .name
407+ args_dict ["spv_name" ] = f"{ arg .name } .spv"
408+
409+
410+ def serialize_args (args , constants , signature ):
411+ import numbers
412+ dir_path = os .getenv ("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS" )
413+ if not os .path .exists (dir_path ):
414+ os .makedirs (dir_path )
415+ print (f"Path to directory consisting of SPIR-V Runner data: { dir_path } " )
416+
417+ cnt = 0
418+ args_dict = {"gridX" : args [cnt ], "gridY" : args [cnt + 1 ], "gridZ" : args [cnt + 2 ]}
419+ args_dict ["argument_list" ] = []
420+ counts = {"tensors" : 0 , "scalars" : 0 , "karg_cnt" : 0 }
421+ cnt = 4
422+ for arg in args [cnt :]:
423+ if type (arg ).__name__ == "KernelMetadata" :
424+ serialize_kernel_metadata (arg , args_dict )
425+
426+ if isinstance (arg , torch .Tensor ):
427+ cpu_tensor = arg .cpu ()
428+ tensor_path = os .path .join (dir_path , f"tensor_{ counts ['tensors' ]} .pt" )
429+ with open (tensor_path , "wb" ) as f :
430+ torch .save (cpu_tensor , f )
431+ new_arg = {
432+ "name" : f"tensor_{ counts ['tensors' ]} " , "type" : "tensor" , "dtype" : str (arg .dtype ), "ctype" :
433+ signature [counts ["karg_cnt" ]]
434+ }
435+ args_dict ["argument_list" ].append (new_arg )
436+ counts ["karg_cnt" ] += 1
437+ counts ["tensors" ] += 1
438+
439+ if isinstance (arg , numbers .Number ):
440+ if counts ["karg_cnt" ] not in constants :
441+ new_arg = {
442+ "name" : f"scalarArg_{ counts ['scalars' ]} " , "type" : "scalar" , "value" : args [cnt ], "ctype" :
443+ signature [counts ["karg_cnt" ]]
444+ }
445+ args_dict ["argument_list" ].append (new_arg )
446+ counts ["karg_cnt" ] += 1
447+ counts ["scalars" ] += 1
448+ cnt += 1
449+ # Dump argument info as a JSON file
450+ json_path = os .path .join (dir_path , "args_data.json" )
451+ with open (json_path , "w" , encoding = "utf-8" ) as json_file :
452+ import json
453+ json .dump (args_dict , json_file , indent = 4 )
454+
455+
402456class XPULauncher :
403457
404458 def __init__ (self , src , metadata ): # pylint: disable=unused-argument
405459 ids = {"ids_of_const_exprs" : src .fn .constexprs if hasattr (src , "fn" ) else tuple ()}
406460 constants = src .constants if hasattr (src , "constants" ) else {}
407461 cst_key = lambda i : src .fn .arg_names .index (i ) if isinstance (i , str ) else i
408- constants = {cst_key (key ): value for key , value in constants .items ()}
409- signature = {cst_key (key ): value for key , value in src .signature .items ()}
410- src = make_launcher (constants , signature , ids )
462+ self . constants = {cst_key (key ): value for key , value in constants .items ()}
463+ self . signature = {cst_key (key ): value for key , value in src .signature .items ()}
464+ src = make_launcher (self . constants , self . signature , ids )
411465 mod = compile_module_from_src (src , "__triton_launcher" )
412466 self .launch = mod .launch
413467
414468 def __call__ (self , * args , ** kwargs ):
469+ # Serialize KernelArguments for SPIR-V Runner
470+ serialize_kernel_args = os .getenv ("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS" , None )
471+ if serialize_kernel_args :
472+ serialize_args (args , self .constants , self .signature )
415473 self .launch (* args , ** kwargs )
416474
417475
0 commit comments