Skip to content

Commit 09ebc1a

Browse files
[Intel] Add an opt pass pipeline for gluon
Fix test failures from c109dc7. Signed-off-by: Whitney Tsang <[email protected]>
1 parent 83f279b commit 09ebc1a

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

third_party/intel/backend/compiler.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from triton.backends.compiler import BaseBackend
1+
from triton.backends.compiler import BaseBackend, Language
22
from triton._C.libtriton import ir, passes, llvm, intel
33
from triton.backends.intel.driver import compile_module_from_src
44
from triton import knobs
@@ -315,6 +315,20 @@ def make_ttgir(mod, metadata, opt, properties):
315315
metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
316316
return mod
317317

318+
@staticmethod
319+
def ttgir_opt(src, metadata, options):
320+
mod = src
321+
pm = ir.pass_manager(mod.context)
322+
pm.enable_debug()
323+
324+
passes.ttgpuir.add_inliner(pm)
325+
passes.ttir.add_loop_aware_cse(pm)
326+
passes.ttgpuir.add_canonicalizer(pm)
327+
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
328+
329+
pm.run(mod)
330+
return mod
331+
318332
@staticmethod
319333
def make_llir(src, metadata, options):
320334
mod = src
@@ -444,9 +458,12 @@ def make_spv(src, metadata, options):
444458
return zebin
445459
return spirv
446460

447-
def add_stages(self, stages, options):
448-
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
449-
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.properties)
461+
def add_stages(self, stages, options, language):
462+
if language == Language.TRITON:
463+
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
464+
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.properties)
465+
elif language == Language.GLUON:
466+
stages["ttgir"] = lambda src, metadata: self.ttgir_opt(src, metadata, options)
450467
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
451468
stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options)
452469

0 commit comments

Comments
 (0)