11from  triton .backends .compiler  import  BaseBackend , GPUTarget 
2- from  triton ._C .libtriton  import  ir , passes 
2+ from  triton ._C .libtriton  import  ir , passes ,  triton_shared 
33from  dataclasses  import  dataclass 
44from  typing  import  Any , Dict , Tuple 
55from  types  import  ModuleType 
@@ -42,7 +42,7 @@ def _get_sanitizer_type():
4242    if  sanitizer_type  !=  ""  and  sanitizer_type  !=  "asan"  and  sanitizer_type  !=  "tsan" :
4343        # throw error 
4444        raise  Exception (f"TRITON_SHARED_SANITIZER_TYPE { sanitizer_type }   is invalid." )
45-      
45+ 
4646    return  sanitizer_type 
4747
4848def  _ttir_to_ttsharedir (mod ):
@@ -78,41 +78,47 @@ def _ttsharedir_to_llir(ttsharedir: str):
7878        llmlir_path  =  os .path .join (tmpdir , "ll.mlir" )
7979        llir_path  =  os .path .join (tmpdir , "ll.ir" )
8080        Path (ttshared_path ).write_text (ttsharedir )
81-         mlir_opt_path  =  _get_llvm_bin_path ("mlir-opt" )
81+         context  =  ir .context ()
82+         triton_shared .ir .load_dialects (context )
83+         mod  =  ir .parse_mlir_module (ttshared_path , context )
8284        # TritonShared-MLIR to LLVM-MLIR 
83-         subprocess .check_call ([mlir_opt_path , ttshared_path ,
84-             "--convert-linalg-to-affine-loops" ,
85-             # Note: eliminate-empty-tensors fails when there are multiple func.return ops 
86-             # in a single kernel which are the results of early returns. 
87-             # See python/examples/test_early_return.py for examples. 
88-             # We disable this pass for now since performance on CPU isn't the main 
89-             # focus at the moment. 
90-             # "--eliminate-empty-tensors", 
91-             "--empty-tensor-to-alloc-tensor" ,
92-             "--one-shot-bufferize=allow-return-allocs-from-loops=true" ,
93-             "--lower-affine" ,
94-             "--convert-linalg-to-loops" ,
95-             "--expand-strided-metadata" ,
96-             "--convert-scf-to-cf" ,
97-             "--convert-arith-to-llvm" ,
98-             "--convert-math-to-llvm" ,
99-             "--convert-complex-to-llvm" ,
100-             "--convert-vector-to-llvm" ,
101-             "--convert-index-to-llvm" ,
102-             "--memref-expand" ,
103-             "--finalize-memref-to-llvm" ,
104-             "--convert-func-to-llvm" ,
105-             "--convert-cf-to-llvm" ,
106-             # Lowering memrefs creates more affine.apply ops. 
107-             # Lowering these affine ops again creates further arith ops, 
108-             # so we have to run these two passes again here. 
109-             "--lower-affine" ,
110-             "--convert-arith-to-llvm" ,
111-             # Remove all unrealized casts created 
112-             "--reconcile-unrealized-casts" ,
113-             "--mlir-print-debuginfo" ,
114-             "-o" ,
115-             llmlir_path ])
85+ 
86+         pm  =  ir .pass_manager (context )
87+         pm .enable_debug ()
88+         triton_shared .to_llir .add_convert_linalg_to_affine_loops (pm )
89+         # Note: eliminate-empty-tensors fails when there are multiple func.return ops 
90+         # in a single kernel which are the results of early returns. 
91+         # See python/examples/test_early_return.py for examples. 
92+         # We disable this pass for now since performance on CPU isn't the main 
93+         # focus at the moment. 
94+         # triton_shared.to_llir.add_eliminate_empty_tensors(pm) 
95+         triton_shared .to_llir .add_empty_tensor_to_alloc_tensor (pm )
96+         triton_shared .to_llir .add_one_shot_bufferize_with_options (
97+             pm , allow_return_allocs_from_loops = True )
98+         triton_shared .to_llir .add_lower_affine (pm )
99+         triton_shared .to_llir .add_convert_linalg_to_loops (pm )
100+         triton_shared .to_llir .add_expand_strided_metadata (pm )
101+         triton_shared .to_llir .add_convert_scf_to_cf (pm )
102+         triton_shared .to_llir .add_convert_tptr_to_llvm (pm )
103+         triton_shared .to_llir .add_convert_arith_to_llvm (pm )
104+         triton_shared .to_llir .add_convert_math_to_llvm (pm )
105+         triton_shared .to_llir .add_convert_complex_to_llvm (pm )
106+         triton_shared .to_llir .add_convert_vector_to_llvm (pm )
107+         triton_shared .to_llir .add_convert_index_to_llvm (pm )
108+         triton_shared .to_llir .add_memref_expand (pm )
109+         triton_shared .to_llir .add_finalize_memref_to_llvm (pm )
110+         triton_shared .to_llir .add_convert_func_to_llvm (pm )
111+         triton_shared .to_llir .add_convert_cf_to_llvm (pm )
112+         # Lowering memrefs creates more affine.apply ops. 
113+         # Lowering these affine ops again creates further arith ops, 
114+         # so we have to run these two passes again here. 
115+         triton_shared .to_llir .add_lower_affine (pm )
116+         triton_shared .to_llir .add_convert_arith_to_llvm (pm )
117+         # Remove all unrealized casts created 
118+         triton_shared .to_llir .add_reconcile_unrealized_casts (pm )
119+         pm .run (mod )
120+ 
121+         Path (llmlir_path ).write_text (str (mod ))
116122
117123        # LLVM-MLIR to LLVM-IR 
118124        mlir_translate_path  =  _get_llvm_bin_path ("mlir-translate" )
@@ -145,16 +151,16 @@ def _llir_to_bin(llir: str, metadata):
145151            # using a sanitizer 
146152            # invoke pass to append sanitizer attributes 
147153            instrumented_src_path  =  os .path .join (tmpdir , "kernel-instrumented.ll" )
148-          
154+ 
149155            opt_path  =  _get_llvm_bin_path ("opt" )
150156            top_level_triton_path  =  os .path .dirname (triton .__file__ )
151157            sanitizer_attributes_pass_path  =  str (next (Path (top_level_triton_path ).rglob ("libSanitizerAttributes.so" ), None ))
152158
153159            if  not  sanitizer_attributes_pass_path :
154160                raise  Exception (f"libSanitizerAttributes.so does not exist." )
155161
156-             subprocess .check_call ([opt_path , "-load-pass-plugin" , sanitizer_attributes_pass_path ,  
157-                 "-passes=sanitizer-attributes" , f"-sanitizer-type={ sanitizer_type }  " , "-S" , src_path ,  
162+             subprocess .check_call ([opt_path , "-load-pass-plugin" , sanitizer_attributes_pass_path ,
163+                 "-passes=sanitizer-attributes" , f"-sanitizer-type={ sanitizer_type }  " , "-S" , src_path ,
158164                "-o" , instrumented_src_path ])
159165
160166            # compile to object file 
@@ -166,12 +172,12 @@ def _llir_to_bin(llir: str, metadata):
166172                subprocess_args .extend (["-g" , "-fsanitize=address" , "-mllvm" , "-asan-stack=0" ])
167173            elif  sanitizer_type  ==  "tsan" :
168174                subprocess_args .extend (["-g" , "-fsanitize=thread" ])
169-                  
175+ 
170176            subprocess .check_call (subprocess_args )
171177        else :
172178            llc_path  =  _get_llvm_bin_path ("llc" )
173179            subprocess .check_call ([llc_path , src_path , "-filetype=obj" , "-relocation-model=pic" , "-o" , dst_path ])
174-          
180+ 
175181        return  Path (dst_path ).read_bytes ()
176182
177183
@@ -265,11 +271,11 @@ def add_stages(self, stages, options, language):
265271        stages ["llir" ] =  lambda  src , metadata : _optimize_llir (_ttsharedir_to_llir (src ))
266272        stages ["obj" ] =  lambda  src , metadata : _llir_to_bin (src , metadata )
267273
268- 
269274    @functools .lru_cache () 
270275    def  hash (self ):
271276        return  self .target 
272277
273278    # The CPU backend does not use any extra python modules, return an empty dictionary 
274279    def  get_module_map (self ) ->  Dict [str , ModuleType ]:
275280        return  {}
281+ 
0 commit comments