1- from __future__ import annotations
2-
31import logging
2+ import os
3+ import sys
4+ import tempfile
5+ from contextlib import ExitStack
6+ from io import StringIO
7+ from typing import Optional
8+
9+ from mlir .ir import StringAttr
10+ from mlir .passmanager import PassManager
11+
12+ from mlir_utils .util import disable_multithreading
413
514logger = logging .getLogger (__name__ )
615
716
17+ class MlirCompilerError (Exception ):
18+ pass
19+
20+
21+ def get_module_name_for_debug_dump (module ):
22+ if "debug_module_name" not in module .operation .attributes :
23+ return "UnnammedModule"
24+ return StringAttr (module .operation .attributes ["debug_module_name" ]).value
25+
26+
27+ def run_pipeline (
28+ module ,
29+ pipeline : str ,
30+ description : Optional [str ] = None ,
31+ enable_ir_printing = False ,
32+ print_pipeline = False ,
33+ ):
34+ """Runs `pipeline` on `module`, with a nice repro report if it fails."""
35+ module_name = get_module_name_for_debug_dump (module )
36+ try :
37+ original_stderr = sys .stderr
38+ sys .stderr = StringIO ()
39+ # Lower module in place to make it ready for compiler backends.
40+ with ExitStack () as stack :
41+ stack .enter_context (module .context )
42+ asm_for_error_report = module .operation .get_asm (
43+ large_elements_limit = 10 ,
44+ enable_debug_info = True ,
45+ )
46+ pm = PassManager .parse (pipeline )
47+ if print_pipeline :
48+ print (pm )
49+ if enable_ir_printing :
50+ stack .enter_context (disable_multithreading ())
51+ pm .enable_ir_printing ()
52+
53+ pm .run (module .operation )
54+ except Exception as e :
55+ print (e , file = sys .stderr )
56+ filename = os .path .join (tempfile .gettempdir (), module_name + ".mlir" )
57+ with open (filename , "w" ) as f :
58+ f .write (asm_for_error_report )
59+ debug_options = "-mlir-print-ir-after-all -mlir-disable-threading"
60+ description = description or f"{ module_name } compile"
61+
62+ message = f"""\
63+ { description } failed with the following diagnostics:
64+
65+ { '*' * 80 }
66+ { sys .stderr .getvalue ().strip ()}
67+ { '*' * 80 }
68+
69+ For developers, the error can be reproduced with:
70+ $ mlir-opt { debug_options } -pass-pipeline='{ pipeline } ' { filename }
71+ """
72+ trimmed_message = "\n " .join ([m .lstrip () for m in message .split ("\n " )])
73+ raise MlirCompilerError (trimmed_message )
74+ finally :
75+ sys .stderr = original_stderr
76+
77+ return module
78+
79+
880class Pipeline :
981 _pipeline : list [str ] = []
1082
@@ -13,17 +85,17 @@ def __init__(self, pipeline=None, wrapper=None):
1385 pipeline = []
1486 self ._pipeline = pipeline
1587
16- def Func (self , p : Pipeline ):
88+ def Func (self , p : " Pipeline" ):
1789 assert isinstance (p , Pipeline )
1890 self ._pipeline .append (f"func.func({ p .materialize (module = False )} )" )
1991 return self
2092
21- def Spirv (self , p : Pipeline ):
93+ def Spirv (self , p : " Pipeline" ):
2294 assert isinstance (p , Pipeline )
2395 self ._pipeline .append (f"spirv.module({ p .materialize (module = False )} )" )
2496 return self
2597
26- def Gpu (self , p : Pipeline ):
98+ def Gpu (self , p : " Pipeline" ):
2799 assert isinstance (p , Pipeline )
28100 self ._pipeline .append (f"gpu.module({ p .materialize (module = False )} )" )
29101 return self
@@ -38,13 +110,6 @@ def materialize(self, module=True):
38110 def __str__ (self ):
39111 return self .materialize ()
40112
41- def __add__ (self , other : Pipeline ):
42- return Pipeline (self ._pipeline + other ._pipeline )
43-
44- def __iadd__ (self , other : Pipeline ):
45- self ._pipeline += other ._pipeline
46- return self
47-
48113 def add_pass (self , pass_name , ** kwargs ):
49114 kwargs = {
50115 k .replace ("_" , "-" ): int (v ) if isinstance (v , bool ) else v
@@ -57,6 +122,7 @@ def add_pass(self, pass_name, **kwargs):
57122 else :
58123 pass_str = f"{ pass_name } "
59124 self ._pipeline .append (pass_str )
125+ return self
60126
61127 def lower_to_llvm_ (self ):
62128 return any (["to-llvm" in p for p in self ._pipeline ])
0 commit comments