Skip to content

Commit b2b83b5

Browse files
authored
Hikettei/polyhedral schedule (#13)
* fix: Final formatting fix * test: Accept ruff formatting for test_fusion.py * test: Accept ruff formatting for test_fusion.py (final) * fix: Format domain.py to satisfy ruff * update * fix: Final formatting fix for example * fix: Final format fix accept * update * update
1 parent be10bd4 commit b2b83b5

File tree

10 files changed

+358
-266
lines changed

10 files changed

+358
-266
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
__pycache__/
33
*.py[cod]
44
*$py.class
5-
5+
.DS_Store
66
# Distribution / packaging
77
.Python
88
build/

AGENTS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
- 既存コードへの余計な修正禁止。Set 単体で完結する API は `caten/isl/specs/set.py` 内で完結させる。UnionSet 依存 API は関数名だけ置いて保留可。自動生成手法は禁止。
3636
- 進捗と作業計画を常に本ファイルに記録し更新すること(型ごとに完了状況や今後の順番を明記)。最新の計画がここに存在する状態を保つ。
3737

38+
## Polyhedral DSL Guidelines
39+
- Prefer using Mixin operator overloads (e.g., `A | B` instead of `A.union(B)`) for cleaner code in user scripts and DSL implementations.
40+
3841
## 作業計画と進捗 (2025-11-16)
3942
直近のギャップ集計: `docs/ISL_missing_apis.md`(2025-11-16 再生成、欠落API 2047件)。map 残 2 件(tuple_name系シンボル未提供のみ、libisl非存在)。
4043
優先順とステータス(✅完了 / 🚧着手中 / ⏳未着手)

caten/polyhedral/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .schedule_tree.filter import filter
77
from .schedule_tree.mark import mark
88
from .schedule_tree.sequence import sequence
9+
from .stmt import stmt
910

1011
__all__ = [
1112
"domain",
@@ -16,4 +17,5 @@
1617
"schedule",
1718
"compute_flow",
1819
"to_c",
20+
"stmt",
1921
]

caten/polyhedral/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import contextvars
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, Any, Optional
55

66
if TYPE_CHECKING:
77
import caten.isl as I
@@ -10,6 +10,7 @@ class ScheduleBuilder:
1010
def __init__(self) -> None:
1111
self.current_node: Optional["I.ScheduleNode"] = None
1212
self.schedule: Optional["I.Schedule"] = None
13+
self.current_domain: Any = None
1314

1415
_builder_ctx: contextvars.ContextVar[Optional[ScheduleBuilder]] = contextvars.ContextVar("schedule_builder", default=None)
1516

caten/polyhedral/poly_schedule.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from typing import Optional
2+
3+
import caten.isl as I
4+
from caten.polyhedral.analysis import compute_dependence_relation, schedule_is_legal_p
5+
from caten.polyhedral.codegen import to_c
6+
7+
8+
class PolyhedralSchedule:
9+
def __init__(self, schedule: "I.Schedule", reads: Optional["I.UnionMap"] = None, writes: Optional["I.UnionMap"] = None) -> None:
10+
self.isl_schedule = schedule
11+
self.reads = reads
12+
self.writes = writes
13+
self.raw_dep: Optional["I.UnionMap"] = None
14+
self.total_dep: Optional["I.UnionMap"] = None
15+
16+
if reads and writes:
17+
self.compute_dependencies()
18+
19+
def compute_dependencies(self) -> None:
20+
if not self.reads or not self.writes:
21+
return
22+
total, raw, waw, war = compute_dependence_relation(self.reads, self.writes, self.isl_schedule)
23+
self.raw_dep = raw
24+
self.total_dep = total
25+
26+
def is_legal(self) -> bool:
27+
# Check legality against RAW dependencies
28+
if self.raw_dep:
29+
return schedule_is_legal_p(self.isl_schedule, self.raw_dep)
30+
return True
31+
32+
def get_root(self) -> "I.ScheduleNode":
33+
return self.isl_schedule.get_root()
34+
35+
def to_c(self) -> str:
36+
return to_c(self.isl_schedule)
37+
38+
def __str__(self) -> str:
39+
return str(self.isl_schedule)
40+
41+
def update(self, node: "I.ScheduleNode") -> None:
42+
"""Update the internal schedule from a modified schedule node."""
43+
self.isl_schedule = node.get_schedule()
44+
45+
def sequence(self, other: "PolyhedralSchedule") -> "PolyhedralSchedule":
46+
"""Combine this schedule with another using isl_schedule_sequence."""
47+
new_sched = self.isl_schedule.sequence(other.isl_schedule)
48+
49+
new_reads = None
50+
if self.reads and other.reads:
51+
new_reads = self.reads.union(other.reads)
52+
elif self.reads:
53+
new_reads = self.reads
54+
elif other.reads:
55+
new_reads = other.reads
56+
57+
new_writes = None
58+
if self.writes and other.writes:
59+
new_writes = self.writes.union(other.writes)
60+
elif self.writes:
61+
new_writes = self.writes
62+
elif other.writes:
63+
new_writes = other.writes
64+
65+
return PolyhedralSchedule(new_sched, reads=new_reads, writes=new_writes)

caten/polyhedral/schedule_tree/domain.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,38 @@ def __enter__(self) -> "domain":
157157
# We set current_node to the child of Domain (the Leaf)
158158
builder.current_node = sched.get_root().child(0)
159159

160+
self._prev_domain = builder.current_domain
161+
builder.current_domain = self
162+
160163
return self
161164

162165
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
163166
builder = get_builder()
164167
if builder.current_node:
165168
self.schedule = builder.current_node.get_schedule()
166169
builder.current_node = None
170+
builder.current_domain = self._prev_domain
171+
172+
def finalize(self, read: Optional[Union[str, "I.UnionMap"]] = None, write: Optional[Union[str, "I.UnionMap"]] = None) -> Any:
173+
from ..poly_schedule import PolyhedralSchedule
174+
175+
if self.schedule is None:
176+
if self.domain_set:
177+
uset = self.domain_set
178+
if isinstance(uset, str):
179+
uset = I.UnionSet(uset)
180+
elif isinstance(uset, I.Set):
181+
uset = I.UnionSet.from_set(uset)
182+
self.schedule = I.Schedule.from_domain(uset)
183+
else:
184+
raise RuntimeError("No domain set for schedule.")
185+
186+
r = read if read else self.reads_map
187+
if isinstance(r, str):
188+
r = I.UnionMap(r)
189+
190+
w = write if write else self.writes_map
191+
if isinstance(w, str):
192+
w = I.UnionMap(w)
167193

168-
def finalize(self, op_context: Any = None) -> Any:
169-
# Placeholder for Kernel creation logic
170-
return self.schedule
194+
return PolyhedralSchedule(self.schedule, reads=r, writes=w)

caten/polyhedral/stmt.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import re
2+
from typing import List, Optional, Tuple
3+
4+
import caten.isl as I
5+
6+
from .context import get_builder
7+
8+
9+
def stmt(expr: str) -> None:
10+
dom = get_builder().current_domain
11+
if dom is None:
12+
raise RuntimeError("stmt() must be used within a P.domain context")
13+
14+
if "=" not in expr:
15+
raise ValueError(f"Invalid statement expression (must contain assignment '='): {expr}")
16+
17+
lhs_str, rhs_str = expr.split("=", 1)
18+
19+
def extract_accesses(s: str) -> List[Tuple[str, str]]:
20+
return re.findall(r"([a-zA-Z_]\w*)\s*\[(.*?)\]", s)
21+
22+
writes = extract_accesses(lhs_str)
23+
reads = extract_accesses(rhs_str)
24+
25+
uset = dom.domain_set
26+
if isinstance(uset, str):
27+
uset = I.UnionSet(uset)
28+
elif isinstance(uset, I.Set):
29+
uset = I.UnionSet.from_set(uset)
30+
31+
new_reads: Optional["I.UnionMap"] = None
32+
new_writes: Optional["I.UnionMap"] = None
33+
34+
def process_set(s: "I.Set") -> None:
35+
nonlocal new_reads, new_writes
36+
s_str = str(s)
37+
if ":" in s_str:
38+
tuple_part = s_str.split(":")[0].strip()
39+
if tuple_part.startswith("{"):
40+
tuple_part = tuple_part[1:].strip()
41+
else:
42+
tuple_part = s_str.strip()
43+
if tuple_part.startswith("{") and tuple_part.endswith("}"):
44+
tuple_part = tuple_part[1:-1].strip()
45+
46+
for (name, indices) in writes:
47+
m_str = f"{{ {tuple_part} -> {name}[{indices}] }}"
48+
m = I.UnionMap(m_str)
49+
if new_writes is None:
50+
new_writes = m
51+
else:
52+
new_writes = new_writes.union(m)
53+
54+
for (name, indices) in reads:
55+
m_str = f"{{ {tuple_part} -> {name}[{indices}] }}"
56+
m = I.UnionMap(m_str)
57+
if new_reads is None:
58+
new_reads = m
59+
else:
60+
new_reads = new_reads.union(m)
61+
62+
set_list = uset.get_set_list()
63+
n = set_list.n_set()
64+
for i in range(n):
65+
process_set(set_list.get_at(i))
66+
67+
if new_reads:
68+
if dom.reads_map:
69+
dom.reads_map = dom.reads_map.union(new_reads)
70+
else:
71+
dom.reads_map = new_reads
72+
73+
if new_writes:
74+
if dom.writes_map:
75+
dom.writes_map = dom.writes_map.union(new_writes)
76+
else:
77+
dom.writes_map = new_writes

0 commit comments

Comments
 (0)