Skip to content

Commit 7a4f7e2

Browse files
authored
fix wrong return types for some of the math dialect stmts (#474)
1 parent 599ad1b commit 7a4f7e2

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

src/kirin/dialects/math/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,15 @@ def gamma(x: float) -> float: ...
8989

9090

9191
@lowering.wraps(stmts.isfinite)
92-
def isfinite(x: float) -> float: ...
92+
def isfinite(x: float) -> bool: ...
9393

9494

9595
@lowering.wraps(stmts.isinf)
96-
def isinf(x: float) -> float: ...
96+
def isinf(x: float) -> bool: ...
9797

9898

9999
@lowering.wraps(stmts.isnan)
100-
def isnan(x: float) -> float: ...
100+
def isnan(x: float) -> bool: ...
101101

102102

103103
@lowering.wraps(stmts.lgamma)

src/kirin/dialects/math/_gen.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def builtin_math_functions():
5656
for arg in sig.parameters.keys()
5757
]
5858
)
59+
if "is" in name:
60+
ret_type = "types.Bool"
61+
else:
62+
ret_type = "types.Float"
5963
f.write(
6064
textwrap.dedent(
6165
f"""
@@ -66,7 +70,7 @@ class {name}(ir.Statement):
6670
name = "{name}"
6771
traits = frozenset({{ir.Pure(), lowering2.FromPythonCall()}})
6872
{fields}
69-
result: ir.ResultValue = info.result(types.Float)
73+
result: ir.ResultValue = info.result({ret_type})
7074
"""
7175
)
7276
)
@@ -116,11 +120,15 @@ class MathMethodTable(MethodTable):
116120
f.write("from kirin import lowering2\n")
117121

118122
for name, obj, sig in builtin_math_functions():
123+
if "is" in name:
124+
ret_type = "bool"
125+
else:
126+
ret_type = "float"
119127
f.write(
120128
textwrap.dedent(
121129
f"""
122130
@lowering2.wraps(stmts.{name})
123-
def {name}({", ".join(f"{arg}: float" for arg in sig.parameters.keys())}) -> float: ...
131+
def {name}({", ".join(f"{arg}: {ret_type}" for arg in sig.parameters.keys())}) -> {ret_type}: ...
124132
"""
125133
)
126134
)

src/kirin/dialects/math/stmts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ class isfinite(ir.Statement):
204204
name = "isfinite"
205205
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
206206
x: ir.SSAValue = info.argument(types.Float)
207-
result: ir.ResultValue = info.result(types.Float)
207+
result: ir.ResultValue = info.result(types.Bool)
208208

209209

210210
@statement(dialect=dialect)
@@ -214,7 +214,7 @@ class isinf(ir.Statement):
214214
name = "isinf"
215215
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
216216
x: ir.SSAValue = info.argument(types.Float)
217-
result: ir.ResultValue = info.result(types.Float)
217+
result: ir.ResultValue = info.result(types.Bool)
218218

219219

220220
@statement(dialect=dialect)
@@ -224,7 +224,7 @@ class isnan(ir.Statement):
224224
name = "isnan"
225225
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
226226
x: ir.SSAValue = info.argument(types.Float)
227-
result: ir.ResultValue = info.result(types.Float)
227+
result: ir.ResultValue = info.result(types.Bool)
228228

229229

230230
@statement(dialect=dialect)

0 commit comments

Comments
 (0)