@@ -127,6 +127,10 @@ def parse_target(self, tgt_prop) -> dict:
127127 dev_prop ['has_subgroup_2d_block_io' ] = tgt_prop .get ('has_subgroup_2d_block_io' , False )
128128 dev_prop ['has_bfloat16_conversions' ] = tgt_prop .get ('has_bfloat16_conversions' , True )
129129
130+ if self .device_arch in self .device_props :
131+ dev_prop .update (self .device_props [self .device_arch ])
132+ return dev_prop
133+
130134 return dev_prop
131135
132136 def parse_options (self , opts ) -> Any :
@@ -202,85 +206,27 @@ def get_split_barrier_scope(opt):
202206 split_barriers_scope = intel .SplitBarrierScope .Subgroup
203207 return split_barriers_scope
204208
205- @classmethod
206- def create_pass_manager (cls , context , add_passes = []):
207- pm = ir .pass_manager (context )
208- pm .enable_debug ()
209- for p in add_passes :
210- if p is None :
211- continue
212- elif isinstance (p , tuple ):
213- p [0 ](pm , * p [1 :])
214- else :
215- p (pm )
216- return pm
217-
218- @classmethod
219- def get_ttir_passes (cls , opt ):
220- return [
221- passes .common .add_inliner ,
222- intel .passes .ttir .add_convert_tdesc_to_block_pointer ,
223- passes .ttir .add_rewrite_tensor_descriptor_to_pointer ,
224- passes .common .add_cse ,
225- passes .common .add_licm ,
226- intel .passes .ttir .add_remove_masks ,
227- intel .passes .ttir .add_fuse_reshape ,
228- passes .common .add_canonicalizer ,
229- passes .ttir .add_combine ,
230- passes .ttir .add_reorder_broadcast ,
231- passes .common .add_cse ,
232- passes .common .add_symbol_dce ,
233- passes .ttir .add_loop_unroll ,
234- ]
235-
236209 @classmethod
237210 @track
238211 def make_ttir (cls , mod , metadata , opt ):
239- pm = cls .create_pass_manager (mod .context , cls .get_ttir_passes (opt ))
212+ pm = ir .pass_manager (mod .context )
213+ pm .enable_debug ()
214+ passes .common .add_inliner (pm )
215+ intel .passes .ttir .add_convert_tdesc_to_block_pointer (pm )
216+ passes .ttir .add_rewrite_tensor_descriptor_to_pointer (pm )
217+ passes .common .add_cse (pm )
218+ passes .common .add_licm (pm )
219+ intel .passes .ttir .add_remove_masks (pm )
220+ intel .passes .ttir .add_fuse_reshape (pm )
221+ passes .common .add_canonicalizer (pm )
222+ passes .ttir .add_combine (pm )
223+ passes .ttir .add_reorder_broadcast (pm )
224+ passes .common .add_cse (pm )
225+ passes .common .add_symbol_dce (pm )
226+ passes .ttir .add_loop_unroll (pm )
240227 pm .run (mod , 'make_ttir' )
241228 return mod
242229
243- @classmethod
244- def get_ttgir_passes (cls , opt ):
245- # fmt: off
246- return [
247- (passes .ttir .add_convert_to_ttgpuir , "xpu" , opt .num_warps , opt .warp_size , opt .num_ctas ),
248- # optimize TTGIR
249- intel .passes .ttgpuir .add_coalesce ,
250- intel .passes .ttgpuir .add_remove_layout_conversions ,
251-
252- intel .passes .ttgpuir .add_accelerate_matmul ,
253- intel .passes .ttgpuir .add_materialize_block_pointer ,
254- intel .passes .ttgpuir .add_remove_layout_conversions ,
255- intel .passes .ttgpuir .add_optimize_dot_operands ,
256- (intel .passes .ttgpuir .add_pipeline , opt .num_stages , cls .get_split_barrier_scope (opt )),
257-
258- intel .passes .ttgpuir .add_reduce_variable_liveness if opt .reduce_variable_liveness else None ,
259-
260- passes .ttgpuir .add_fuse_nested_loops ,
261-
262- passes .common .add_canonicalizer ,
263- passes .ttir .add_triton_licm ,
264- passes .common .add_canonicalizer ,
265- passes .ttgpuir .add_combine_tensor_select_and_if ,
266-
267- passes .ttgpuir .add_optimize_thread_locality ,
268- (passes .ttgpuir .add_optimize_dot_operands , True ),
269- passes .common .add_cse ,
270- passes .ttgpuir .add_prefetch ,
271- (passes .ttgpuir .add_optimize_dot_operands , True ),
272- intel .passes .ttgpuir .add_remove_layout_conversions ,
273- intel .passes .ttgpuir .add_reduce_data_duplication ,
274- passes .ttgpuir .add_reorder_instructions ,
275- passes .common .add_cse ,
276- passes .common .add_symbol_dce ,
277- passes .common .add_sccp ,
278- passes .common .add_canonicalizer ,
279- intel .passes .ttgpuir .add_optimize_reduction_locality if knobs .intel .opt_reduction_locality else None ,
280- (intel .passes .arith .add_arith_emulate_unsupported_floats , ["bf16" ], "f32" )
281- ]
282- # fmt: on
283-
284230 @classmethod
285231 @track
286232 def make_ttgir (cls , mod , metadata , opt , properties ):
@@ -291,7 +237,8 @@ def make_ttgir(cls, mod, metadata, opt, properties):
291237 cluster_info .clusterDimZ = opt .cluster_dims [2 ]
292238
293239 # Annotate module with information required by subsequent transformations.
294- pm = cls .create_pass_manager (mod .context )
240+ pm = ir .pass_manager (mod .context )
241+ pm .enable_debug ()
295242 module_opts = intel .passes .ttgpuir .AnnotateModuleOptions ()
296243 cls .annotate_module (module_opts , properties , opt )
297244 intel .passes .ttgpuir .add_triton_annotate_module (pm , module_opts )
@@ -301,7 +248,44 @@ def make_ttgir(cls, mod, metadata, opt, properties):
301248 opt .warp_size = intel .get_threads_per_warp (mod )
302249 cls .validate_options (opt , properties )
303250
304- pm = cls .create_pass_manager (mod .context , cls .get_ttgir_passes (opt ))
251+ pm = ir .pass_manager (mod .context )
252+ pm .enable_debug ()
253+ passes .ttir .add_convert_to_ttgpuir (pm , "xpu" , opt .num_warps , opt .warp_size , opt .num_ctas )
254+ # optimize TTGIR
255+ intel .passes .ttgpuir .add_coalesce (pm )
256+ intel .passes .ttgpuir .add_remove_layout_conversions (pm )
257+
258+ intel .passes .ttgpuir .add_accelerate_matmul (pm )
259+ intel .passes .ttgpuir .add_materialize_block_pointer (pm )
260+ intel .passes .ttgpuir .add_remove_layout_conversions (pm )
261+ intel .passes .ttgpuir .add_optimize_dot_operands (pm )
262+ intel .passes .ttgpuir .add_pipeline (pm , opt .num_stages , XPUBackend .get_split_barrier_scope (opt ))
263+
264+ if (opt .reduce_variable_liveness ):
265+ intel .passes .ttgpuir .add_reduce_variable_liveness (pm )
266+
267+ passes .ttgpuir .add_fuse_nested_loops (pm )
268+
269+ passes .common .add_canonicalizer (pm )
270+ passes .ttir .add_triton_licm (pm )
271+ passes .common .add_canonicalizer (pm )
272+ passes .ttgpuir .add_combine_tensor_select_and_if (pm )
273+
274+ passes .ttgpuir .add_optimize_thread_locality (pm )
275+ passes .ttgpuir .add_optimize_dot_operands (pm , True )
276+ passes .common .add_cse (pm )
277+ passes .ttgpuir .add_prefetch (pm )
278+ passes .ttgpuir .add_optimize_dot_operands (pm , True )
279+ intel .passes .ttgpuir .add_remove_layout_conversions (pm )
280+ intel .passes .ttgpuir .add_reduce_data_duplication (pm )
281+ passes .ttgpuir .add_reorder_instructions (pm )
282+ passes .common .add_cse (pm )
283+ passes .common .add_symbol_dce (pm )
284+ passes .common .add_sccp (pm )
285+ passes .common .add_canonicalizer (pm )
286+ if knobs .intel .opt_reduction_locality :
287+ intel .passes .ttgpuir .add_optimize_reduction_locality (pm )
288+ intel .passes .arith .add_arith_emulate_unsupported_floats (pm , ["bf16" ], "f32" )
305289 pm .run (mod , 'make_ttgir' )
306290 metadata ["cluster_dims" ] = (cluster_info .clusterDimX , cluster_info .clusterDimY , cluster_info .clusterDimZ )
307291 return mod
@@ -322,31 +306,6 @@ def gluon_to_ttgir(self, src, metadata, options):
322306 metadata ["tensordesc_meta" ] = mod .get_tensordesc_metadata ()
323307 return mod
324308
325- @classmethod
326- def get_llir_passes (cls , opt , mod ):
327- # fmt: off
328- return [
329- passes .convert .add_scf_to_cf ,
330- passes .gluon .add_inliner ,
331- passes .convert .add_index_to_llvmir ,
332- intel .passes .ttgpuir .add_allocate_shared_memory ,
333- passes .ttgpuir .add_allocate_global_scratch_memory ,
334- # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
335- lambda pm : cls .instrumentation .patch ("ttgpuir_to_llvmir" , pm , mod .context ) if cls .instrumentation else None ,
336- intel .passes .ttgpuir .add_to_llvmir ,
337- intel .passes .ttgpuir .add_gen_to_llvm ,
338- passes .common .add_canonicalizer ,
339- intel .passes .ttgpuir .add_rewrite_stack_ptr ,
340- passes .common .add_cse ,
341- passes .convert .add_arith_to_llvmir ,
342- passes .common .add_canonicalizer ,
343- passes .common .add_cse ,
344- passes .common .add_symbol_dce ,
345- None if knobs .compilation .disable_line_info or knobs .compilation .dump_ir_extract_di_local_variables else passes .llvmir .add_di_scope ,
346- lambda pm : cls .instrumentation .patch ("llvmir_to_llvm" , pm , mod .context ) if cls .instrumentation else None ,
347- ]
348- # fmt: on
349-
350309 @classmethod
351310 def optimize_llvm_mod (cls , llvm_mod , options ):
352311 intel .set_spv_target_triple (llvm_mod )
@@ -358,21 +317,50 @@ def optimize_llvm_mod(cls, llvm_mod, options):
358317 def make_llir (cls , src , metadata , options ):
359318 mod = src
360319 # TritonGPU -> LLVM-IR (MLIR)
361- pm = cls .create_pass_manager (mod .context , cls .get_llir_passes (options , mod ))
320+ pm = ir .pass_manager (mod .context )
321+ pm .enable_debug ()
322+
323+ passes .convert .add_scf_to_cf (pm )
324+ passes .gluon .add_inliner (pm )
325+ passes .convert .add_index_to_llvmir (pm )
326+ intel .passes .ttgpuir .add_allocate_shared_memory (pm )
327+ passes .ttgpuir .add_allocate_global_scratch_memory (pm )
328+ # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
329+ if cls .instrumentation :
330+ cls .instrumentation .patch ("ttgpuir_to_llvmir" , pm , mod .context )
331+ intel .passes .ttgpuir .add_to_llvmir (pm )
332+ intel .passes .ttgpuir .add_gen_to_llvm (pm )
333+ passes .common .add_canonicalizer (pm )
334+ intel .passes .ttgpuir .add_rewrite_stack_ptr (pm )
335+ passes .common .add_cse (pm )
336+ passes .convert .add_arith_to_llvmir (pm )
337+ passes .common .add_canonicalizer (pm )
338+ passes .common .add_cse (pm )
339+ passes .common .add_symbol_dce (pm )
340+
341+ if not knobs .compilation .disable_line_info and not knobs .compilation .dump_ir_extract_di_local_variables :
342+ passes .llvmir .add_di_scope (pm )
343+
344+ if cls .instrumentation :
345+ cls .instrumentation .patch ("llvmir_to_llvm" , pm , mod .context )
362346 pm .run (mod , 'make_llir' )
363347
364348 if knobs .compilation .dump_ir_extract_di_local_variables :
365349 # see the comments below on why this is run separately
366350 if not knobs .compilation .disable_line_info :
367- pm = cls .create_pass_manager (mod .context , [passes .llvmir .add_di_scope ])
351+ pm = ir .pass_manager (mod .context )
352+ pm .enable_debug ()
353+ passes .llvmir .add_di_scope (pm )
368354 pm .run (mod , 'make_llir.disable_line_info' )
369355
370356 # Insert debug intrinsics with several DI attributes, including source
371357 # variable names and type info. Note: for reasons not yet understood, this
372358 # pass and add_di_scope have to be run separately; otherwise, if we
373359 # put them into the previous pipeline, it triggers a segmentation fault without
374360 # any error message; this could be due to a bug in MLIR or pybind11.
375- pm = cls .create_pass_manager (mod .context , [passes .llvmir .add_di_local_variable ])
361+ pm = ir .pass_manager (mod .context )
362+ pm .enable_debug ()
363+ passes .llvmir .add_di_local_variable (pm )
376364 pm .run (mod , 'make_llir.dump_ir_extract_di_local_variables' )
377365
378366 # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
0 commit comments