|
1 |
| -from triton.backends.compiler import BaseBackend |
| 1 | +from triton.backends.compiler import BaseBackend, Language |
2 | 2 | from triton._C.libtriton import ir, passes, llvm, intel
|
3 | 3 | from triton.backends.intel.driver import compile_module_from_src
|
4 | 4 | from triton import knobs
|
@@ -315,6 +315,20 @@ def make_ttgir(mod, metadata, opt, properties):
|
315 | 315 | metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
|
316 | 316 | return mod
|
317 | 317 |
|
| 318 | + @staticmethod |
| 319 | + def ttgir_opt(src, metadata, options): |
| 320 | + mod = src |
| 321 | + pm = ir.pass_manager(mod.context) |
| 322 | + pm.enable_debug() |
| 323 | + |
| 324 | + passes.ttgpuir.add_inliner(pm) |
| 325 | + passes.ttir.add_loop_aware_cse(pm) |
| 326 | + passes.ttgpuir.add_canonicalizer(pm) |
| 327 | + passes.ttgpuir.add_combine_tensor_select_and_if(pm) |
| 328 | + |
| 329 | + pm.run(mod) |
| 330 | + return mod |
| 331 | + |
318 | 332 | @staticmethod
|
319 | 333 | def make_llir(src, metadata, options):
|
320 | 334 | mod = src
|
@@ -444,9 +458,12 @@ def make_spv(src, metadata, options):
|
444 | 458 | return zebin
|
445 | 459 | return spirv
|
446 | 460 |
|
447 |
| - def add_stages(self, stages, options): |
448 |
| - stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) |
449 |
| - stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.properties) |
| 461 | + def add_stages(self, stages, options, language): |
| 462 | + if language == Language.TRITON: |
| 463 | + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) |
| 464 | + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.properties) |
| 465 | + elif language == Language.GLUON: |
| 466 | + stages["ttgir"] = lambda src, metadata: self.ttgir_opt(src, metadata, options) |
450 | 467 | stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
|
451 | 468 | stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options)
|
452 | 469 |
|
|
0 commit comments