@@ -60,9 +60,7 @@ def builtin_math_functions():
6060 ret_type = "types.Bool"
6161 else :
6262 ret_type = "types.Float"
63- f .write (
64- textwrap .dedent (
65- f"""
63+ f .write (textwrap .dedent (f"""
6664@statement(dialect=dialect)
6765class { name } (ir.Statement):
6866 \" \" \" { name } statement, wrapping the math.{ name } function
@@ -71,9 +69,7 @@ class {name}(ir.Statement):
7169 traits = frozenset({{ir.Pure(), lowering2.FromPythonCall()}})
7270{ fields }
7371 result: ir.ResultValue = info.result({ ret_type } )
74- """
75- )
76- )
72+ """ ))
7773
7874
7975with open (os .path .join (os .path .dirname (__file__ ), "interp.py" ), "w" ) as f :
@@ -89,23 +85,19 @@ class {name}(ir.Statement):
8985 fields = ", " .join (
9086 [f"values[{ idx } ]" for idx , _ in enumerate (sig .parameters .keys ())]
9187 )
92- implements .append (
93- f"""
88+ implements .append (f"""
9489 @impl(stmts.{ name } )
9590 def { name } (self, interp, frame: Frame, stmt: stmts.{ name } ):
9691 values = frame.get_values(stmt.args)
97- return (math.{ name } ({ fields } ),)"""
98- )
92+ return (math.{ name } ({ fields } ),)""" )
9993
10094 # Write the interpreter class
10195 methods = "\n \n " .join (implements )
102- f .write (
103- f"""
96+ f .write (f"""
10497@dialect.register
10598class MathMethodTable(MethodTable):
10699{ methods }
107- """
108- )
100+ """ )
109101
110102# __init__.py
111103with open (os .path .join (os .path .dirname (__file__ ), "__init__.py" ), "w" ) as f :
@@ -124,14 +116,10 @@ class MathMethodTable(MethodTable):
124116 ret_type = "bool"
125117 else :
126118 ret_type = "float"
127- f .write (
128- textwrap .dedent (
129- f"""
119+ f .write (textwrap .dedent (f"""
130120 @lowering2.wraps(stmts.{ name } )
131121 def { name } ({ ", " .join (f"{ arg } : { ret_type } " for arg in sig .parameters .keys ())} ) -> { ret_type } : ...
132- """
133- )
134- )
122+ """ ))
135123 f .write ("\n " )
136124
137125for file in ["__init__.py" , "interp.py" , "stmts.py" ]:
@@ -178,16 +166,12 @@ def {name}({", ".join(f"{arg}: {ret_type}" for arg in sig.parameters.keys())}) -
178166 args = ", " .join (arg for arg in sig .parameters .keys ())
179167 inputs = ", " .join ("0.42" for _ in sig .parameters .keys ())
180168
181- f .write (
182- textwrap .dedent (
183- f"""
169+ f .write (textwrap .dedent (f"""
184170 @basic
185171 def { name } _func({ args } ):
186172 return math.{ name } ({ args } )
187173
188174 def test_{ name } ():
189175 truth = pymath.{ name } ({ inputs } )
190176 assert ({ name } _func({ inputs } ) - truth) < 1e-6
191- """
192- )
193- )
177+ """ ))
0 commit comments