Skip to content

Commit 82ab83b

Browse files
Removed get_*ir_passes()
1 parent 62610b3 commit 82ab83b

File tree

1 file changed

+91
-103
lines changed

1 file changed

+91
-103
lines changed

third_party/intel/backend/compiler.py

Lines changed: 91 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def parse_target(self, tgt_prop) -> dict:
127127
dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
128128
dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)
129129

130+
if self.device_arch in self.device_props:
131+
dev_prop.update(self.device_props[self.device_arch])
132+
return dev_prop
133+
130134
return dev_prop
131135

132136
def parse_options(self, opts) -> Any:
@@ -202,85 +206,27 @@ def get_split_barrier_scope(opt):
202206
split_barriers_scope = intel.SplitBarrierScope.Subgroup
203207
return split_barriers_scope
204208

205-
@classmethod
206-
def create_pass_manager(cls, context, add_passes=[]):
207-
pm = ir.pass_manager(context)
208-
pm.enable_debug()
209-
for p in add_passes:
210-
if p is None:
211-
continue
212-
elif isinstance(p, tuple):
213-
p[0](pm, *p[1:])
214-
else:
215-
p(pm)
216-
return pm
217-
218-
@classmethod
219-
def get_ttir_passes(cls, opt):
220-
return [
221-
passes.common.add_inliner,
222-
intel.passes.ttir.add_convert_tdesc_to_block_pointer,
223-
passes.ttir.add_rewrite_tensor_descriptor_to_pointer,
224-
passes.common.add_cse,
225-
passes.common.add_licm,
226-
intel.passes.ttir.add_remove_masks,
227-
intel.passes.ttir.add_fuse_reshape,
228-
passes.common.add_canonicalizer,
229-
passes.ttir.add_combine,
230-
passes.ttir.add_reorder_broadcast,
231-
passes.common.add_cse,
232-
passes.common.add_symbol_dce,
233-
passes.ttir.add_loop_unroll,
234-
]
235-
236209
@classmethod
237210
@track
238211
def make_ttir(cls, mod, metadata, opt):
239-
pm = cls.create_pass_manager(mod.context, cls.get_ttir_passes(opt))
212+
pm = ir.pass_manager(mod.context)
213+
pm.enable_debug()
214+
passes.common.add_inliner(pm)
215+
intel.passes.ttir.add_convert_tdesc_to_block_pointer(pm)
216+
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
217+
passes.common.add_cse(pm)
218+
passes.common.add_licm(pm)
219+
intel.passes.ttir.add_remove_masks(pm)
220+
intel.passes.ttir.add_fuse_reshape(pm)
221+
passes.common.add_canonicalizer(pm)
222+
passes.ttir.add_combine(pm)
223+
passes.ttir.add_reorder_broadcast(pm)
224+
passes.common.add_cse(pm)
225+
passes.common.add_symbol_dce(pm)
226+
passes.ttir.add_loop_unroll(pm)
240227
pm.run(mod, 'make_ttir')
241228
return mod
242229

243-
@classmethod
244-
def get_ttgir_passes(cls, opt):
245-
# fmt: off
246-
return [
247-
(passes.ttir.add_convert_to_ttgpuir, "xpu", opt.num_warps, opt.warp_size, opt.num_ctas),
248-
# optimize TTGIR
249-
intel.passes.ttgpuir.add_coalesce,
250-
intel.passes.ttgpuir.add_remove_layout_conversions,
251-
252-
intel.passes.ttgpuir.add_accelerate_matmul,
253-
intel.passes.ttgpuir.add_materialize_block_pointer,
254-
intel.passes.ttgpuir.add_remove_layout_conversions,
255-
intel.passes.ttgpuir.add_optimize_dot_operands,
256-
(intel.passes.ttgpuir.add_pipeline, opt.num_stages, cls.get_split_barrier_scope(opt)),
257-
258-
intel.passes.ttgpuir.add_reduce_variable_liveness if opt.reduce_variable_liveness else None,
259-
260-
passes.ttgpuir.add_fuse_nested_loops,
261-
262-
passes.common.add_canonicalizer,
263-
passes.ttir.add_triton_licm,
264-
passes.common.add_canonicalizer,
265-
passes.ttgpuir.add_combine_tensor_select_and_if,
266-
267-
passes.ttgpuir.add_optimize_thread_locality,
268-
(passes.ttgpuir.add_optimize_dot_operands, True),
269-
passes.common.add_cse,
270-
passes.ttgpuir.add_prefetch,
271-
(passes.ttgpuir.add_optimize_dot_operands, True),
272-
intel.passes.ttgpuir.add_remove_layout_conversions,
273-
intel.passes.ttgpuir.add_reduce_data_duplication,
274-
passes.ttgpuir.add_reorder_instructions,
275-
passes.common.add_cse,
276-
passes.common.add_symbol_dce,
277-
passes.common.add_sccp,
278-
passes.common.add_canonicalizer,
279-
intel.passes.ttgpuir.add_optimize_reduction_locality if knobs.intel.opt_reduction_locality else None,
280-
(intel.passes.arith.add_arith_emulate_unsupported_floats, ["bf16"], "f32")
281-
]
282-
# fmt: on
283-
284230
@classmethod
285231
@track
286232
def make_ttgir(cls, mod, metadata, opt, properties):
@@ -291,7 +237,8 @@ def make_ttgir(cls, mod, metadata, opt, properties):
291237
cluster_info.clusterDimZ = opt.cluster_dims[2]
292238

293239
# Annotate module with information required by subsequent transformations.
294-
pm = cls.create_pass_manager(mod.context)
240+
pm = ir.pass_manager(mod.context)
241+
pm.enable_debug()
295242
module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
296243
cls.annotate_module(module_opts, properties, opt)
297244
intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
@@ -301,7 +248,44 @@ def make_ttgir(cls, mod, metadata, opt, properties):
301248
opt.warp_size = intel.get_threads_per_warp(mod)
302249
cls.validate_options(opt, properties)
303250

304-
pm = cls.create_pass_manager(mod.context, cls.get_ttgir_passes(opt))
251+
pm = ir.pass_manager(mod.context)
252+
pm.enable_debug()
253+
passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.warp_size, opt.num_ctas)
254+
# optimize TTGIR
255+
intel.passes.ttgpuir.add_coalesce(pm)
256+
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
257+
258+
intel.passes.ttgpuir.add_accelerate_matmul(pm)
259+
intel.passes.ttgpuir.add_materialize_block_pointer(pm)
260+
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
261+
intel.passes.ttgpuir.add_optimize_dot_operands(pm)
262+
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt))
263+
264+
if (opt.reduce_variable_liveness):
265+
intel.passes.ttgpuir.add_reduce_variable_liveness(pm)
266+
267+
passes.ttgpuir.add_fuse_nested_loops(pm)
268+
269+
passes.common.add_canonicalizer(pm)
270+
passes.ttir.add_triton_licm(pm)
271+
passes.common.add_canonicalizer(pm)
272+
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
273+
274+
passes.ttgpuir.add_optimize_thread_locality(pm)
275+
passes.ttgpuir.add_optimize_dot_operands(pm, True)
276+
passes.common.add_cse(pm)
277+
passes.ttgpuir.add_prefetch(pm)
278+
passes.ttgpuir.add_optimize_dot_operands(pm, True)
279+
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
280+
intel.passes.ttgpuir.add_reduce_data_duplication(pm)
281+
passes.ttgpuir.add_reorder_instructions(pm)
282+
passes.common.add_cse(pm)
283+
passes.common.add_symbol_dce(pm)
284+
passes.common.add_sccp(pm)
285+
passes.common.add_canonicalizer(pm)
286+
if knobs.intel.opt_reduction_locality:
287+
intel.passes.ttgpuir.add_optimize_reduction_locality(pm)
288+
intel.passes.arith.add_arith_emulate_unsupported_floats(pm, ["bf16"], "f32")
305289
pm.run(mod, 'make_ttgir')
306290
metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
307291
return mod
@@ -322,31 +306,6 @@ def gluon_to_ttgir(self, src, metadata, options):
322306
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
323307
return mod
324308

325-
@classmethod
326-
def get_llir_passes(cls, opt, mod):
327-
# fmt: off
328-
return [
329-
passes.convert.add_scf_to_cf,
330-
passes.gluon.add_inliner,
331-
passes.convert.add_index_to_llvmir,
332-
intel.passes.ttgpuir.add_allocate_shared_memory,
333-
passes.ttgpuir.add_allocate_global_scratch_memory,
334-
# instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
335-
lambda pm: cls.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context) if cls.instrumentation else None,
336-
intel.passes.ttgpuir.add_to_llvmir,
337-
intel.passes.ttgpuir.add_gen_to_llvm,
338-
passes.common.add_canonicalizer,
339-
intel.passes.ttgpuir.add_rewrite_stack_ptr,
340-
passes.common.add_cse,
341-
passes.convert.add_arith_to_llvmir,
342-
passes.common.add_canonicalizer,
343-
passes.common.add_cse,
344-
passes.common.add_symbol_dce,
345-
None if knobs.compilation.disable_line_info or knobs.compilation.dump_ir_extract_di_local_variables else passes.llvmir.add_di_scope,
346-
lambda pm: cls.instrumentation.patch("llvmir_to_llvm", pm, mod.context) if cls.instrumentation else None,
347-
]
348-
# fmt: on
349-
350309
@classmethod
351310
def optimize_llvm_mod(cls, llvm_mod, options):
352311
intel.set_spv_target_triple(llvm_mod)
@@ -358,21 +317,50 @@ def optimize_llvm_mod(cls, llvm_mod, options):
358317
def make_llir(cls, src, metadata, options):
359318
mod = src
360319
# TritonGPU -> LLVM-IR (MLIR)
361-
pm = cls.create_pass_manager(mod.context, cls.get_llir_passes(options, mod))
320+
pm = ir.pass_manager(mod.context)
321+
pm.enable_debug()
322+
323+
passes.convert.add_scf_to_cf(pm)
324+
passes.gluon.add_inliner(pm)
325+
passes.convert.add_index_to_llvmir(pm)
326+
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
327+
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
328+
# instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
329+
if cls.instrumentation:
330+
cls.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
331+
intel.passes.ttgpuir.add_to_llvmir(pm)
332+
intel.passes.ttgpuir.add_gen_to_llvm(pm)
333+
passes.common.add_canonicalizer(pm)
334+
intel.passes.ttgpuir.add_rewrite_stack_ptr(pm)
335+
passes.common.add_cse(pm)
336+
passes.convert.add_arith_to_llvmir(pm)
337+
passes.common.add_canonicalizer(pm)
338+
passes.common.add_cse(pm)
339+
passes.common.add_symbol_dce(pm)
340+
341+
if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
342+
passes.llvmir.add_di_scope(pm)
343+
344+
if cls.instrumentation:
345+
cls.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
362346
pm.run(mod, 'make_llir')
363347

364348
if knobs.compilation.dump_ir_extract_di_local_variables:
365349
# comments below on why separate it
366350
if not knobs.compilation.disable_line_info:
367-
pm = cls.create_pass_manager(mod.context, [passes.llvmir.add_di_scope])
351+
pm = ir.pass_manager(mod.context)
352+
pm.enable_debug()
353+
passes.llvmir.add_di_scope(pm)
368354
pm.run(mod, 'make_llir.disable_line_info')
369355

370356
# insert dbg intrinsic with several DI Attribute including source
371357
# var name and type info note: unknown reason for now, but this
372358
# pass and add_di_scope has to be run separately, otherwise if we
373359
# put them into previous pipline, it trigger a segmentfault without
374360
# any error message; could be due to a bug in mlir or pybind11
375-
pm = cls.create_pass_manager(mod.context, [passes.llvmir.add_di_local_variable])
361+
pm = ir.pass_manager(mod.context)
362+
pm.enable_debug()
363+
passes.llvmir.add_di_local_variable(pm)
376364
pm.run(mod, 'make_llir.dump_ir_extract_di_local_variables')
377365

378366
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)

0 commit comments

Comments
 (0)