1+ import re
2+ from typing import List , Optional , Tuple
3+
4+ import caten .isl as I
5+
6+ from .context import get_builder
7+
8+
9+ def stmt (expr : str ) -> None :
10+ dom = get_builder ().current_domain
11+ if dom is None :
12+ raise RuntimeError ("stmt() must be used within a P.domain context" )
13+
14+ if "=" not in expr :
15+ raise ValueError (f"Invalid statement expression (must contain assignment '='): { expr } " )
16+
17+ lhs_str , rhs_str = expr .split ("=" , 1 )
18+
19+ def extract_accesses (s : str ) -> List [Tuple [str , str ]]:
20+ return re .findall (r"([a-zA-Z_]\w*)\s*\[(.*?)\]" , s )
21+
22+ writes = extract_accesses (lhs_str )
23+ reads = extract_accesses (rhs_str )
24+
25+ uset = dom .domain_set
26+ if isinstance (uset , str ):
27+ uset = I .UnionSet (uset )
28+ elif isinstance (uset , I .Set ):
29+ uset = I .UnionSet .from_set (uset )
30+
31+ new_reads : Optional ["I.UnionMap" ] = None
32+ new_writes : Optional ["I.UnionMap" ] = None
33+
34+ def process_set (s : "I.Set" ) -> None :
35+ nonlocal new_reads , new_writes
36+ s_str = str (s )
37+ if ":" in s_str :
38+ tuple_part = s_str .split (":" )[0 ].strip ()
39+ if tuple_part .startswith ("{" ):
40+ tuple_part = tuple_part [1 :].strip ()
41+ else :
42+ tuple_part = s_str .strip ()
43+ if tuple_part .startswith ("{" ) and tuple_part .endswith ("}" ):
44+ tuple_part = tuple_part [1 :- 1 ].strip ()
45+
46+ for (name , indices ) in writes :
47+ m_str = f"{{ { tuple_part } -> { name } [{ indices } ] }}"
48+ m = I .UnionMap (m_str )
49+ if new_writes is None :
50+ new_writes = m
51+ else :
52+ new_writes = new_writes .union (m )
53+
54+ for (name , indices ) in reads :
55+ m_str = f"{{ { tuple_part } -> { name } [{ indices } ] }}"
56+ m = I .UnionMap (m_str )
57+ if new_reads is None :
58+ new_reads = m
59+ else :
60+ new_reads = new_reads .union (m )
61+
62+ set_list = uset .get_set_list ()
63+ n = set_list .n_set ()
64+ for i in range (n ):
65+ process_set (set_list .get_at (i ))
66+
67+ if new_reads :
68+ if dom .reads_map :
69+ dom .reads_map = dom .reads_map .union (new_reads )
70+ else :
71+ dom .reads_map = new_reads
72+
73+ if new_writes :
74+ if dom .writes_map :
75+ dom .writes_map = dom .writes_map .union (new_writes )
76+ else :
77+ dom .writes_map = new_writes
0 commit comments