Skip to content

Commit a814fd4

Browse files
committed
adds skeleton for ISL operation pool
1 parent 3fe69e0 commit a814fd4

File tree

1 file changed

+222
-0
lines changed

1 file changed

+222
-0
lines changed

islpy/oppool.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
from __future__ import annotations
2+
3+
__copyright__ = "Copyright (C) 2021 Kaushik Kulkarni"
4+
5+
__license__ = """
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in
14+
all copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
THE SOFTWARE.
23+
"""
24+
25+
26+
from functools import cached_property
27+
import islpy as isl
28+
from dataclasses import dataclass, field
29+
from typing import Union, Dict, Any, Optional, Tuple
30+
from pytools import UniqueNameGenerator
31+
32+
33+
BaseType = Union[isl.Aff, isl.BasicSet, isl.Set, isl.BasicMap, isl.Map]
34+
BASE_CLASSES = (isl.Aff, isl.BasicSet, isl.Set, isl.BasicMap, isl.Map)
35+
36+
37+
def normalize(obj: BaseType) -> BaseType:
38+
vng = UniqueNameGenerator(forced_prefix="_islpy")
39+
40+
lift_map = {}
41+
new_obj = obj
42+
43+
for old_name, (dt, pos) in obj.get_var_dict().items():
44+
if dt == isl.dim_type.param:
45+
new_name = vng("param")
46+
elif dt == isl.dim_type.set:
47+
new_name = vng("set")
48+
elif dt == isl.dim_type.in_:
49+
new_name = vng("in")
50+
elif dt == isl.dim_type.out:
51+
new_name = vng("out")
52+
else:
53+
raise NotImplementedError(dt)
54+
55+
new_obj = new_obj.set_dim_name(dt, pos, new_name)
56+
lift_map[new_name] = old_name
57+
58+
return new_obj, lift_map
59+
60+
61+
@dataclass
62+
class NormalizedISLObj:
63+
ground_obj: BaseType
64+
lift_map: Dict[str, str]
65+
66+
def lift(self) -> BaseType:
67+
new_obj = self.ground_obj
68+
69+
for old_name, (dt, pos) in new_obj.get_var_dict().items():
70+
new_obj = new_obj.set_dim_name(dt, pos, self.lift_map[old_name])
71+
72+
return new_obj
73+
74+
@cached_property
75+
def unlift_map(self) -> Dict[str, str]:
76+
return {v: k for k, v in self.lift_map.items()}
77+
78+
def copy(self, ground_obj: Optional[BaseType] = None,
79+
lift_map: Optional[Dict[str, str]] = None) -> NormalizedISLObj:
80+
if ground_obj is None:
81+
ground_obj = self.ground_obj
82+
if lift_map is None:
83+
lift_map = self.lift_map.copy()
84+
85+
return type(self)(ground_obj, lift_map)
86+
87+
def post_init(self):
88+
def _no_user(id: isl.Id):
89+
try:
90+
isl.user
91+
except TypeError:
92+
return True
93+
else:
94+
return False
95+
96+
assert all(_no_user(id_) for id_ in self.ground_obj.get_id_dict())
97+
98+
def get_dim_name(self, op_pool: ISLOpMemoizer,
99+
type: isl.dim_type, pos: int) -> str:
100+
base_name = op_pool(type(self.ground_obj).get_dim_name,
101+
(self.ground_obj, pos))
102+
return self.lift_map[base_name]
103+
104+
def set_dim_name(self, op_pool: ISLOpMemoizer,
105+
type: isl.dim_type, pos: int,
106+
s: str) -> NormalizedISLObj:
107+
base_name = op_pool(type(self.ground_obj).get_dim_name,
108+
(self.ground_obj, pos))
109+
lift_map = self.lift_map.copy()
110+
lift_map[base_name] = s
111+
return self.copy(lift_map=lift_map)
112+
113+
def get_dim_id(self, op_pool: ISLOpMemoizer,
114+
type: isl.dim_type, pos: int) -> isl.Id:
115+
base_name = op_pool(type(self.ground_obj).get_dim_name,
116+
(self.ground_obj, pos))
117+
return isl.Id(self.lift_map[base_name])
118+
119+
def set_dim_id(self, op_pool: ISLOpMemoizer,
120+
type: isl.dim_type, pos: int,
121+
id: isl.Id) -> NormalizedISLObj:
122+
try:
123+
id.user
124+
except TypeError:
125+
pass
126+
else:
127+
raise ValueError("Normalized ISL object cannot have user object in"
128+
"ids.")
129+
return self.set_dim_name(op_pool, type, pos, id.get_name())
130+
131+
def get_id_dict(self, op_pool: ISLOpMemoizer):
132+
ground_dict = op_pool(type(self.ground_obj).get_id_dict, (self.ground_obj,))
133+
return {isl.Id(self.lift_map[k.name]): v for k, v in ground_dict.items()}
134+
135+
def get_var_dict(self, op_pool: ISLOpMemoizer):
136+
ground_dict = op_pool(type(self.ground_obj).get_var_dict, (self.ground_obj,))
137+
return {self.lift_map[k]: v for k, v in ground_dict.items()}
138+
139+
140+
class NormalizedISLBasicSet(NormalizedISLObj):
141+
@staticmethod
142+
def read_from_str(ctx: isl.Context, s: str) -> NormalizedISLBasicSet:
143+
ground_obj, lift_map = normalize(isl.BasicSet(s))
144+
return NormalizedISLBasicSet(ground_obj, lift_map)
145+
146+
def intersect(self, op_pool: ISLOpMemoizer,
147+
other: NormalizedISLBasicSet) -> NormalizedISLBasicSet:
148+
if self.lift_map != other.lift_map:
149+
raise ValueError("spaces don't match")
150+
res_ground = op_pool(isl.BasicSet.intersect,
151+
(self.ground_obj, other.ground_obj))
152+
return NormalizedISLBasicSet(res_ground, self.lift_map)
153+
154+
155+
class NormalizedISLSet(NormalizedISLObj):
156+
@staticmethod
157+
def read_from_str(ctx: isl.Context, s: str) -> NormalizedISLSet:
158+
ground_obj, lift_map = normalize(isl.Set(s))
159+
return NormalizedISLSet(ground_obj, lift_map)
160+
161+
def intersect(self, op_pool: ISLOpMemoizer,
162+
other: NormalizedISLSet) -> NormalizedISLBasicSet:
163+
if self.lift_map != other.lift_map:
164+
raise ValueError("spaces don't match")
165+
res_ground = op_pool(isl.Set.intersect,
166+
(self.ground_obj, other.ground_obj))
167+
return NormalizedISLBasicSet(res_ground, self.lift_map)
168+
169+
170+
class NormalizedISLBasicMap(NormalizedISLObj):
171+
@staticmethod
172+
def read_from_str(ctx: isl.Context, s: str) -> NormalizedISLBasicMap:
173+
ground_obj, lift_map = normalize(isl.BasicMap(s))
174+
return NormalizedISLBasicMap(ground_obj, lift_map)
175+
176+
177+
class NormalizedISLMap(NormalizedISLObj):
178+
@staticmethod
179+
def read_from_str(ctx: isl.Context, s: str) -> NormalizedISLMap:
180+
ground_obj, lift_map = normalize(isl.Map(s))
181+
return NormalizedISLMap(ground_obj, lift_map)
182+
183+
184+
class NormalizedISLAff(NormalizedISLObj):
185+
@staticmethod
186+
def read_from_str(ctx: isl.Context, s: str) -> NormalizedISLAff:
187+
ground_obj, lift_map = normalize(isl.Aff(s))
188+
return NormalizedISLAff(ground_obj, lift_map)
189+
190+
191+
class NormalizedISLPwAff(NormalizedISLObj):
192+
@staticmethod
193+
def read_from_str(ctx: isl.Context, s: str) -> NormalizedISLPwAff:
194+
ground_obj, lift_map = normalize(isl.PwAff(s))
195+
return NormalizedISLPwAff(ground_obj, lift_map)
196+
197+
198+
class NormalizedISLQPolynomial(NormalizedISLObj):
199+
@staticmethod
200+
def read_from_str(ctx: isl.Context, s: str) -> NormalizedISLQPolynomial:
201+
ground_obj, lift_map = normalize(isl.QPolynomial(s))
202+
return NormalizedISLQPolynomial(ground_obj, lift_map)
203+
204+
205+
class NormalizedISLPwQPolynomial(NormalizedISLObj):
206+
@staticmethod
207+
def read_from_str(ctx: isl.Context, s: str) -> NormalizedISLPwQPolynomial:
208+
ground_obj, lift_map = normalize(isl.PwQPolynomial(s))
209+
return NormalizedISLPwQPolynomial(ground_obj, lift_map)
210+
211+
212+
@dataclass
213+
class ISLOpMemoizer:
214+
cache: Dict[Any, Any] = field(default_factory=dict)
215+
216+
def __call__(self, f, args: Tuple[Any, ...]):
217+
try:
218+
return self.cache[(f, args)]
219+
except KeyError:
220+
result = f(*args)
221+
self.cache[(f, args)] = result
222+
return result

0 commit comments

Comments
 (0)