-
Notifications
You must be signed in to change notification settings - Fork 1
Hikettei/polyhedral schedule #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
432bd77
12c665e
3228fc6
bd3823a
4fed54e
a8412cf
602bb84
7cff760
4340545
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,7 @@ | |
| __pycache__/ | ||
| *.py[cod] | ||
| *$py.class | ||
|
|
||
| .DS_Store | ||
| # Distribution / packaging | ||
| .Python | ||
| build/ | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,65 @@ | ||||||||||||||||||||||
| from typing import Optional | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import caten.isl as I | ||||||||||||||||||||||
| from caten.polyhedral.analysis import compute_dependence_relation, schedule_is_legal_p | ||||||||||||||||||||||
| from caten.polyhedral.codegen import to_c | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class PolyhedralSchedule: | ||||||||||||||||||||||
| def __init__(self, schedule: "I.Schedule", reads: Optional["I.UnionMap"] = None, writes: Optional["I.UnionMap"] = None) -> None: | ||||||||||||||||||||||
| self.isl_schedule = schedule | ||||||||||||||||||||||
| self.reads = reads | ||||||||||||||||||||||
| self.writes = writes | ||||||||||||||||||||||
| self.raw_dep: Optional["I.UnionMap"] = None | ||||||||||||||||||||||
| self.total_dep: Optional["I.UnionMap"] = None | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if reads and writes: | ||||||||||||||||||||||
| self.compute_dependencies() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def compute_dependencies(self) -> None: | ||||||||||||||||||||||
| if not self.reads or not self.writes: | ||||||||||||||||||||||
| return | ||||||||||||||||||||||
| total, raw, waw, war = compute_dependence_relation(self.reads, self.writes, self.isl_schedule) | ||||||||||||||||||||||
| self.raw_dep = raw | ||||||||||||||||||||||
| self.total_dep = total | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def is_legal(self) -> bool: | ||||||||||||||||||||||
| # Check legality against RAW dependencies | ||||||||||||||||||||||
| if self.raw_dep: | ||||||||||||||||||||||
| return schedule_is_legal_p(self.isl_schedule, self.raw_dep) | ||||||||||||||||||||||
| return True | ||||||||||||||||||||||
|
Comment on lines
+26
to
+30
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def get_root(self) -> "I.ScheduleNode": | ||||||||||||||||||||||
| return self.isl_schedule.get_root() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def to_c(self) -> str: | ||||||||||||||||||||||
| return to_c(self.isl_schedule) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __str__(self) -> str: | ||||||||||||||||||||||
| return str(self.isl_schedule) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def update(self, node: "I.ScheduleNode") -> None: | ||||||||||||||||||||||
| """Update the internal schedule from a modified schedule node.""" | ||||||||||||||||||||||
| self.isl_schedule = node.get_schedule() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def sequence(self, other: "PolyhedralSchedule") -> "PolyhedralSchedule": | ||||||||||||||||||||||
| """Combine this schedule with another using isl_schedule_sequence.""" | ||||||||||||||||||||||
| new_sched = self.isl_schedule.sequence(other.isl_schedule) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| new_reads = None | ||||||||||||||||||||||
| if self.reads and other.reads: | ||||||||||||||||||||||
| new_reads = self.reads.union(other.reads) | ||||||||||||||||||||||
| elif self.reads: | ||||||||||||||||||||||
| new_reads = self.reads | ||||||||||||||||||||||
| elif other.reads: | ||||||||||||||||||||||
| new_reads = other.reads | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| new_writes = None | ||||||||||||||||||||||
| if self.writes and other.writes: | ||||||||||||||||||||||
| new_writes = self.writes.union(other.writes) | ||||||||||||||||||||||
| elif self.writes: | ||||||||||||||||||||||
| new_writes = self.writes | ||||||||||||||||||||||
| elif other.writes: | ||||||||||||||||||||||
| new_writes = other.writes | ||||||||||||||||||||||
|
Comment on lines
+49
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for combining new_reads = self.reads | other.reads if self.reads and other.reads else self.reads or other.reads
new_writes = self.writes | other.writes if self.writes and other.writes else self.writes or other.writes |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return PolyhedralSchedule(new_sched, reads=new_reads, writes=new_writes) | ||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -157,14 +157,38 @@ def __enter__(self) -> "domain": | |
| # We set current_node to the child of Domain (the Leaf) | ||
| builder.current_node = sched.get_root().child(0) | ||
|
|
||
| self._prev_domain = builder.current_domain | ||
| builder.current_domain = self | ||
|
|
||
| return self | ||
|
|
||
| def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: | ||
| builder = get_builder() | ||
| if builder.current_node: | ||
| self.schedule = builder.current_node.get_schedule() | ||
| builder.current_node = None | ||
| builder.current_domain = self._prev_domain | ||
|
|
||
| def finalize(self, read: Optional[Union[str, "I.UnionMap"]] = None, write: Optional[Union[str, "I.UnionMap"]] = None) -> Any: | ||
| from ..poly_schedule import PolyhedralSchedule | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The import of |
||
|
|
||
| if self.schedule is None: | ||
| if self.domain_set: | ||
| uset = self.domain_set | ||
| if isinstance(uset, str): | ||
| uset = I.UnionSet(uset) | ||
| elif isinstance(uset, I.Set): | ||
| uset = I.UnionSet.from_set(uset) | ||
| self.schedule = I.Schedule.from_domain(uset) | ||
|
Comment on lines
+176
to
+182
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| else: | ||
| raise RuntimeError("No domain set for schedule.") | ||
|
|
||
| r = read if read else self.reads_map | ||
| if isinstance(r, str): | ||
| r = I.UnionMap(r) | ||
|
|
||
| w = write if write else self.writes_map | ||
| if isinstance(w, str): | ||
| w = I.UnionMap(w) | ||
|
|
||
| def finalize(self, op_context: Any = None) -> Any: | ||
| # Placeholder for Kernel creation logic | ||
| return self.schedule | ||
| return PolyhedralSchedule(self.schedule, reads=r, writes=w) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,77 @@ | ||||||||||||||||||||||||||||||||||
| import re | ||||||||||||||||||||||||||||||||||
| from typing import List, Optional, Tuple | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| import caten.isl as I | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| from .context import get_builder | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def stmt(expr: str) -> None: | ||||||||||||||||||||||||||||||||||
| dom = get_builder().current_domain | ||||||||||||||||||||||||||||||||||
| if dom is None: | ||||||||||||||||||||||||||||||||||
| raise RuntimeError("stmt() must be used within a P.domain context") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if "=" not in expr: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"Invalid statement expression (must contain assignment '='): {expr}") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| lhs_str, rhs_str = expr.split("=", 1) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def extract_accesses(s: str) -> List[Tuple[str, str]]: | ||||||||||||||||||||||||||||||||||
| return re.findall(r"([a-zA-Z_]\w*)\s*\[(.*?)\]", s) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| writes = extract_accesses(lhs_str) | ||||||||||||||||||||||||||||||||||
| reads = extract_accesses(rhs_str) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| uset = dom.domain_set | ||||||||||||||||||||||||||||||||||
| if isinstance(uset, str): | ||||||||||||||||||||||||||||||||||
| uset = I.UnionSet(uset) | ||||||||||||||||||||||||||||||||||
| elif isinstance(uset, I.Set): | ||||||||||||||||||||||||||||||||||
| uset = I.UnionSet.from_set(uset) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| new_reads: Optional["I.UnionMap"] = None | ||||||||||||||||||||||||||||||||||
| new_writes: Optional["I.UnionMap"] = None | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def process_set(s: "I.Set") -> None: | ||||||||||||||||||||||||||||||||||
| nonlocal new_reads, new_writes | ||||||||||||||||||||||||||||||||||
| s_str = str(s) | ||||||||||||||||||||||||||||||||||
| if ":" in s_str: | ||||||||||||||||||||||||||||||||||
| tuple_part = s_str.split(":")[0].strip() | ||||||||||||||||||||||||||||||||||
| if tuple_part.startswith("{"): | ||||||||||||||||||||||||||||||||||
| tuple_part = tuple_part[1:].strip() | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| tuple_part = s_str.strip() | ||||||||||||||||||||||||||||||||||
| if tuple_part.startswith("{") and tuple_part.endswith("}"): | ||||||||||||||||||||||||||||||||||
| tuple_part = tuple_part[1:-1].strip() | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+36
to
+44
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extracting the tuple part by parsing the string representation of an |
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| for (name, indices) in writes: | ||||||||||||||||||||||||||||||||||
| m_str = f"{{ {tuple_part} -> {name}[{indices}] }}" | ||||||||||||||||||||||||||||||||||
| m = I.UnionMap(m_str) | ||||||||||||||||||||||||||||||||||
| if new_writes is None: | ||||||||||||||||||||||||||||||||||
| new_writes = m | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| new_writes = new_writes.union(m) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| for (name, indices) in reads: | ||||||||||||||||||||||||||||||||||
| m_str = f"{{ {tuple_part} -> {name}[{indices}] }}" | ||||||||||||||||||||||||||||||||||
| m = I.UnionMap(m_str) | ||||||||||||||||||||||||||||||||||
| if new_reads is None: | ||||||||||||||||||||||||||||||||||
| new_reads = m | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| new_reads = new_reads.union(m) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| set_list = uset.get_set_list() | ||||||||||||||||||||||||||||||||||
| n = set_list.n_set() | ||||||||||||||||||||||||||||||||||
| for i in range(n): | ||||||||||||||||||||||||||||||||||
| process_set(set_list.get_at(i)) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if new_reads: | ||||||||||||||||||||||||||||||||||
| if dom.reads_map: | ||||||||||||||||||||||||||||||||||
| dom.reads_map = dom.reads_map.union(new_reads) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| dom.reads_map = new_reads | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if new_writes: | ||||||||||||||||||||||||||||||||||
| if dom.writes_map: | ||||||||||||||||||||||||||||||||||
| dom.writes_map = dom.writes_map.union(new_writes) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| dom.writes_map = new_writes | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+67
to
+77
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for updating
Suggested change
|
||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type hint for
current_domainisAny, which is not very specific. To improve type safety and code clarity, consider using a more specific type. A forward referenceOptional["domain"]would be more appropriate here, assuming it refers to thedomainclass fromcaten.polyhedral.schedule_tree.domain. You will need to add the import under aTYPE_CHECKINGblock to avoid circular dependencies.