Skip to content

Commit 45f62c1

Browse files
authored
This PR add vmath dialect (#473)
Add vectorized math dialect. address #471
1 parent 7a4f7e2 commit 45f62c1

File tree

7 files changed

+1270
-0
lines changed

7 files changed

+1270
-0
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ dependencies = [
1111
readme = "README.md"
1212
requires-python = ">= 3.10"
1313

14+
[project.optional-dependencies]
15+
vmath = [
16+
"numpy>=2.2.6",
17+
"scipy>=1.15.3",
18+
]
1419
[build-system]
1520
requires = ["hatchling"]
1621
build-backend = "hatchling.build"
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import math as pymath
2+
from typing import TypeVar
3+
4+
from kirin import lowering
5+
from kirin.dialects import ilist
6+
7+
from . import stmts as stmts, interp as interp
8+
from ._dialect import dialect as dialect
9+
10+
pi = pymath.pi
11+
e = pymath.e
12+
tau = pymath.tau
13+
14+
ListLen = TypeVar("ListLen")
15+
16+
17+
@lowering.wraps(stmts.acos)
18+
def acos(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
19+
20+
21+
@lowering.wraps(stmts.asin)
22+
def asin(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
23+
24+
25+
@lowering.wraps(stmts.asinh)
26+
def asinh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
27+
28+
29+
@lowering.wraps(stmts.atan)
30+
def atan(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
31+
32+
33+
@lowering.wraps(stmts.atan2)
34+
def atan2(
35+
y: ilist.IList[float, ListLen], x: ilist.IList[float, ListLen]
36+
) -> ilist.IList[float, ListLen]: ...
37+
38+
39+
@lowering.wraps(stmts.atanh)
40+
def atanh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
41+
42+
43+
@lowering.wraps(stmts.ceil)
44+
def ceil(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
45+
46+
47+
@lowering.wraps(stmts.copysign)
48+
def copysign(
49+
x: ilist.IList[float, ListLen], y: ilist.IList[float, ListLen]
50+
) -> ilist.IList[float, ListLen]: ...
51+
52+
53+
@lowering.wraps(stmts.cos)
54+
def cos(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
55+
56+
57+
@lowering.wraps(stmts.cosh)
58+
def cosh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
59+
60+
61+
@lowering.wraps(stmts.degrees)
62+
def degrees(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
63+
64+
65+
@lowering.wraps(stmts.erf)
66+
def erf(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
67+
68+
69+
@lowering.wraps(stmts.erfc)
70+
def erfc(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
71+
72+
73+
@lowering.wraps(stmts.exp)
74+
def exp(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
75+
76+
77+
@lowering.wraps(stmts.expm1)
78+
def expm1(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
79+
80+
81+
@lowering.wraps(stmts.fabs)
82+
def fabs(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
83+
84+
85+
@lowering.wraps(stmts.floor)
86+
def floor(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
87+
88+
89+
@lowering.wraps(stmts.fmod)
90+
def fmod(
91+
x: ilist.IList[float, ListLen], y: ilist.IList[float, ListLen]
92+
) -> ilist.IList[float, ListLen]: ...
93+
94+
95+
@lowering.wraps(stmts.gamma)
96+
def gamma(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
97+
98+
99+
@lowering.wraps(stmts.isfinite)
100+
def isfinite(x: ilist.IList[float, ListLen]) -> ilist.IList[bool, ListLen]: ...
101+
102+
103+
@lowering.wraps(stmts.isinf)
104+
def isinf(x: ilist.IList[float, ListLen]) -> ilist.IList[bool, ListLen]: ...
105+
106+
107+
@lowering.wraps(stmts.isnan)
108+
def isnan(x: ilist.IList[float, ListLen]) -> ilist.IList[bool, ListLen]: ...
109+
110+
111+
@lowering.wraps(stmts.lgamma)
112+
def lgamma(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
113+
114+
115+
@lowering.wraps(stmts.log10)
116+
def log10(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
117+
118+
119+
@lowering.wraps(stmts.log1p)
120+
def log1p(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
121+
122+
123+
@lowering.wraps(stmts.log2)
124+
def log2(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
125+
126+
127+
@lowering.wraps(stmts.pow)
128+
def pow(
129+
x: ilist.IList[float, ListLen], y: ilist.IList[float, ListLen]
130+
) -> ilist.IList[float, ListLen]: ...
131+
132+
133+
@lowering.wraps(stmts.radians)
134+
def radians(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
135+
136+
137+
@lowering.wraps(stmts.remainder)
138+
def remainder(
139+
x: ilist.IList[float, ListLen], y: ilist.IList[float, ListLen]
140+
) -> ilist.IList[float, ListLen]: ...
141+
142+
143+
@lowering.wraps(stmts.sin)
144+
def sin(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
145+
146+
147+
@lowering.wraps(stmts.sinh)
148+
def sinh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
149+
150+
151+
@lowering.wraps(stmts.sqrt)
152+
def sqrt(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
153+
154+
155+
@lowering.wraps(stmts.tan)
156+
def tan(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
157+
158+
159+
@lowering.wraps(stmts.tanh)
160+
def tanh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
161+
162+
163+
@lowering.wraps(stmts.trunc)
164+
def trunc(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("vmath")

0 commit comments

Comments
 (0)