@@ -60,9 +60,7 @@ def builtin_math_functions():
6060 ret_type = "types.Bool"
6161 else:
6262 ret_type = "types.Float"
63- f.write(
64- textwrap.dedent(
65- f"""
63+ f.write(textwrap.dedent(f"""
6664@statement(dialect=dialect)
6765class {name}(ir.Statement):
6866 \"\"\"{name} statement, wrapping the math.{name} function
@@ -71,9 +69,7 @@ class {name}(ir.Statement):
7169 traits = frozenset({{ir.Pure(), lowering2.FromPythonCall()}})
7270{fields}
7371 result: ir.ResultValue = info.result({ret_type})
74- """
75- )
76- )
72+ """))
7773
7874
7975with open(os.path.join(os.path.dirname(__file__), "interp.py"), "w") as f:
@@ -89,23 +85,19 @@ class {name}(ir.Statement):
8985 fields = ", ".join(
9086 [f"values[{idx}]" for idx, _ in enumerate(sig.parameters.keys())]
9187 )
92- implements.append(
93- f"""
88+ implements.append(f"""
9489 @impl(stmts.{name})
9590 def {name}(self, interp, frame: Frame, stmt: stmts.{name}):
9691 values = frame.get_values(stmt.args)
97- return (math.{name}({fields}),)"""
98- )
92+ return (math.{name}({fields}),)""")
9993
10094 # Write the interpreter class
10195 methods = "\n\n".join(implements)
102- f.write(
103- f"""
96+ f.write(f"""
10497@dialect.register
10598class MathMethodTable(MethodTable):
10699{methods}
107- """
108- )
100+ """)
109101
110102# __init__.py
111103with open(os.path.join(os.path.dirname(__file__), "__init__.py"), "w") as f:
@@ -124,14 +116,10 @@ class MathMethodTable(MethodTable):
124116 ret_type = "bool"
125117 else:
126118 ret_type = "float"
127- f.write(
128- textwrap.dedent(
129- f"""
119+ f.write(textwrap.dedent(f"""
130120 @lowering2.wraps(stmts.{name})
131121 def {name}({", ".join(f"{arg}: {ret_type}" for arg in sig.parameters.keys())}) -> {ret_type}: ...
132- """
133- )
134- )
122+ """))
135123 f.write("\n")
136124
137125for file in ["__init__.py", "interp.py", "stmts.py"]:
@@ -178,16 +166,12 @@ def {name}({", ".join(f"{arg}: {ret_type}" for arg in sig.parameters.keys())}) -
178166 args = ", ".join(arg for arg in sig.parameters.keys())
179167 inputs = ", ".join("0.42" for _ in sig.parameters.keys())
180168
181- f.write(
182- textwrap.dedent(
183- f"""
169+ f.write(textwrap.dedent(f"""
184170 @basic
185171 def {name}_func({args}):
186172 return math.{name}({args})
187173
188174 def test_{name}():
189175 truth = pymath.{name}({inputs})
190176 assert ({name}_func({inputs}) - truth) < 1e-6
191- """
192- )
193- )
177+ """))
0 commit comments