Conversation
Summary of ChangesHello @hikettei, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the polyhedral scheduling DSL by introducing a dedicated Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a PolyhedralSchedule class and a stmt function to create a more user-friendly DSL for polyhedral scheduling, which is a significant improvement. The refactoring of the conv2d_pool2d_fusion example effectively showcases the benefits of this new API. My review includes a critical fix for a missing function in a new example, along with several medium-severity suggestions to enhance type safety, improve code conciseness by refactoring repetitive logic, and align the new code with the newly introduced DSL guidelines by implementing operator overloading.
* fix: Final formatting fix * test: Accept ruff formatting for test_fusion.py * test: Accept ruff formatting for test_fusion.py (final) * fix: Format domain.py to satisfy ruff * update * fix: Final formatting fix for example * fix: Final format fix accept * update * update
There was a problem hiding this comment.
Code Review
This pull request introduces a PolyhedralSchedule class and a stmt DSL function, which significantly simplify the process of creating and manipulating polyhedral schedules. The refactoring of the conv2d_pool2d_fusion.py example demonstrates the improved readability and usability of the new API. Overall, this is a valuable addition.
However, I've identified a few issues that should be addressed. There's a critical correctness issue in the is_legal method, which could lead to incorrect schedule validation. Additionally, the test coverage threshold has been lowered, and the new modules lack dedicated unit tests, which is a concern for maintainability. I've also noted several opportunities for refactoring to reduce code duplication and to adhere to the newly introduced coding guidelines. Please see my detailed comments below.
| def is_legal(self) -> bool: | ||
| # Check legality against RAW dependencies | ||
| if self.raw_dep: | ||
| return schedule_is_legal_p(self.isl_schedule, self.raw_dep) | ||
| return True |
There was a problem hiding this comment.
The is_legal method currently checks for schedule legality only against Read-After-Write (RAW) dependencies. For a schedule to be truly legal, it must respect all data dependencies, including Write-After-Write (WAW) and Write-After-Read (WAR). The compute_dependencies method already calculates self.total_dep, which is the union of all dependencies. You should use self.total_dep for a correct legality check.
| def is_legal(self) -> bool: | |
| # Check legality against RAW dependencies | |
| if self.raw_dep: | |
| return schedule_is_legal_p(self.isl_schedule, self.raw_dep) | |
| return True | |
| def is_legal(self) -> bool: | |
| # Check legality against all dependencies | |
| if self.total_dep: | |
| return schedule_is_legal_p(self.isl_schedule, self.total_dep) | |
| return True |
|
|
||
| [tool.pytest.ini_options] | ||
| addopts = "--cov=caten --cov-report=term-missing --cov-fail-under=70" | ||
| addopts = "--cov=caten --cov-report=term-missing --cov-fail-under=60" |
There was a problem hiding this comment.
Lowering the test coverage threshold from 70% to 60% is a significant step backward for the project's code quality and maintainability. This change suggests that the new functionality is not adequately tested. Instead of reducing the quality bar, please add the necessary tests to meet the original coverage requirement. This ensures that new code is as reliable as the existing codebase.
| import caten.isl as I | ||
| import caten.polyhedral as P | ||
| from caten.polyhedral.analysis import compute_dependence_relation, schedule_is_legal_p | ||
| from caten.polyhedral.transformations import schedule_node_sequence_full_fuse | ||
|
|
||
|
|
||
| def create_conv_schedule(N, K_out, H_out, W_out, Cin, KH, KW): | ||
| dom_str = f"{{ S_conv[n, k, h, w, c, kh, kw] : 0<=n<{N} and 0<=k<{K_out} and 0<=h<{H_out} and 0<=w<{W_out} and 0<=c<{Cin} and 0<=kh<{KH} and 0<=kw<{KW} }}" | ||
|
|
||
| with P.domain(dom_str) as conv: | ||
| with P.band("{ S_conv[n, k, h, w, c, kh, kw] -> [n, k, h, w, c, kh, kw] }"): | ||
| P.stmt("Out[n, k, h, w] = Out[n, k, h, w], In[n, c, h, w], W[k, c, kh, kw]") | ||
|
|
||
| return conv.finalize() | ||
|
|
||
| def create_pool_schedule(N, K_out, H_pool, W_pool, S_pool, KH_pool, KW_pool): | ||
| dom_str = f"{{ S_pool[n, k, h, w, rh, rw] : 0<=n<{N} and 0<=k<{K_out} and 0<=h<{H_pool} and 0<=w<{W_pool} and 0<=rh<{KH_pool} and 0<=rw<{KW_pool} }}" | ||
|
|
||
| with P.domain(dom_str) as pool: | ||
| with P.band("{ S_pool[n, k, h, w, rh, rw] -> [n, k, h, w, rh, rw] }"): | ||
| P.stmt(f"PoolBuf[n, k, h, w] = PoolBuf[n, k, h, w], Out[n, k, h*{S_pool} + rh, w*{S_pool} + rw]") | ||
|
|
||
| return pool.finalize() | ||
|
|
||
| def test_conv2d_pool2d_fusion(): | ||
| # Parameters | ||
| N = 10 | ||
| Cin = 16 | ||
| Cout = 32 | ||
| H_in, W_in = 32, 32 | ||
|
|
||
| KH_conv, KW_conv = 3, 3 | ||
| S_conv = 1 | ||
| H_conv = (H_in - KH_conv) // S_conv + 1 | ||
| W_conv = (W_in - KW_conv) // S_conv + 1 | ||
|
|
||
| KH_pool, KW_pool = 2, 2 | ||
| S_pool = 2 | ||
| H_pool = (H_conv - KH_pool) // S_pool + 1 | ||
| W_pool = (W_conv - KW_pool) // S_pool + 1 | ||
|
|
||
| Tile_H = S_pool | ||
| Tile_W = S_pool | ||
|
|
||
| # Domains | ||
| conv_dom_str = f"{{ S_conv[n, k, h, w, c, kh, kw] : 0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_conv} and 0<=w<{W_conv} and 0<=c<{Cin} and 0<=kh<{KH_conv} and 0<=kw<{KW_conv} }}" | ||
| pool_dom_str = f"{{ S_pool[n, k, h, w, rh, rw] : 0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_pool} and 0<=w<{W_pool} and 0<=rh<{KH_pool} and 0<=rw<{KW_pool} }}" | ||
|
|
||
| with I.context(): | ||
| # Access Relations (including Reduction Dependencies) | ||
| writes_conv = I.UnionMap( | ||
| f"{{ S_conv[n, k, h, w, c, kh, kw] -> Out[n, k, h, w] : " | ||
| f"0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_conv} and 0<=w<{W_conv} }}" | ||
| ) | ||
| reads_conv_acc = I.UnionMap( | ||
| f"{{ S_conv[n, k, h, w, c, kh, kw] -> Out[n, k, h, w] : " | ||
| f"0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_conv} and 0<=w<{W_conv} }}" | ||
| ) | ||
|
|
||
| reads_pool = I.UnionMap( | ||
| f"{{ S_pool[n, k, h, w, rh, rw] -> Out[n, k, h*{S_pool} + rh, w*{S_pool} + rw] : " | ||
| f"0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_pool} and 0<=w<{W_pool} and 0<=rh<{KH_pool} and 0<=rw<{KW_pool} }}" | ||
| ) | ||
| writes_pool = I.UnionMap( | ||
| "{ S_pool[n, k, h, w, rh, rw] -> PoolBuf[n, k, h, w] }" | ||
| ) | ||
| reads_pool_acc = I.UnionMap( | ||
| "{ S_pool[n, k, h, w, rh, rw] -> PoolBuf[n, k, h, w] }" | ||
| ) | ||
|
|
||
| all_writes = writes_conv.union(writes_pool) | ||
| all_reads = reads_pool.union(reads_conv_acc).union(reads_pool_acc) | ||
|
|
||
| # Initial Schedule | ||
| filters = I.UnionSetList.alloc(2) | ||
| filters = filters.add(I.UnionSet(conv_dom_str)) | ||
| filters = filters.add(I.UnionSet(pool_dom_str)) | ||
|
|
||
| full_dom = I.UnionSet(conv_dom_str).union(I.UnionSet(pool_dom_str)) | ||
| sched = I.Schedule.from_domain(full_dom) | ||
|
|
||
| root = sched.get_root() | ||
| seq_node = root.child(0).insert_sequence(filters) | ||
|
|
||
| conv_filter = seq_node.child(0) | ||
| conv_mupa = I.MultiUnionPwAff.from_union_map(I.UnionMap("{ S_conv[n, k, h, w, c, kh, kw] -> [n, k, h, w, c, kh, kw] }")) | ||
| conv_node = conv_filter.child(0).insert_partial_schedule(conv_mupa) | ||
|
|
||
| seq_node = conv_node.parent().parent() | ||
| pool_filter = seq_node.child(1) | ||
| pool_mupa = I.MultiUnionPwAff.from_union_map(I.UnionMap("{ S_pool[n, k, h, w, rh, rw] -> [n, k, h, w, rh, rw] }")) | ||
| pool_node = pool_filter.child(0).insert_partial_schedule(pool_mupa) | ||
|
|
||
| initial_sched = pool_node.get_schedule() | ||
|
|
||
| total_dep, raw, waw, war = compute_dependence_relation( | ||
| read=all_reads, | ||
| write=all_writes, | ||
| schedule=initial_sched | ||
| ) | ||
|
|
||
| legal = schedule_is_legal_p(initial_sched, total_dep) | ||
| assert legal | ||
|
|
||
| # --- Transformations --- | ||
| def get_seq_from_band(band): | ||
| return band.parent().parent() | ||
|
|
||
| # Split Bands | ||
| seq_node = get_seq_from_band(pool_node) | ||
| conv_band = seq_node.child(0).child(0) | ||
| conv_band = conv_band.band_split(2) | ||
| conv_band_2 = conv_band.child(0) | ||
| conv_band_2 = conv_band_2.band_split(2) | ||
|
|
||
| # Use correct navigation for nested bands | ||
| seq_node = conv_band_2.parent().parent().parent() | ||
|
|
||
| pool_band = seq_node.child(1).child(0) | ||
| pool_band = pool_band.band_split(2) | ||
| pool_band_2 = pool_band.child(0) | ||
| pool_band_2 = pool_band_2.band_split(2) | ||
|
|
||
| seq_node = pool_band_2.parent().parent().parent() | ||
|
|
||
| # Fuse NK | ||
| nk_band = schedule_node_sequence_full_fuse(seq_node) | ||
|
|
||
| # Tile Conv HW | ||
| seq_node = nk_band.child(0) | ||
| conv_hw = seq_node.child(0).child(0) | ||
| conv_hw = conv_hw.band_set_permutable(1) | ||
| space = conv_hw.band_get_space() | ||
|
|
||
| mv = I.MultiVal.zero(space) | ||
| mv = mv.set_val(0, I.Val.int_from_si(Tile_H)) | ||
| mv = mv.set_val(1, I.Val.int_from_si(Tile_W)) | ||
|
|
||
| conv_tiled = conv_hw.band_tile(mv) | ||
| conv_tiled = conv_tiled.band_scale_down(mv) | ||
|
|
||
| # Replace Inner Band | ||
| inner = conv_tiled.child(0) | ||
| replaced = inner.delete() | ||
| new_mupa = I.MultiUnionPwAff.from_union_map(I.UnionMap(f"{{ S_conv[n, k, h, w, c, kh, kw] -> [(h%{Tile_H}), (w%{Tile_W})] }}")) | ||
| conv_tiled_inner = replaced.insert_partial_schedule(new_mupa) # noqa: F841 | ||
|
|
||
| seq_node = conv_tiled_inner.parent().parent().parent() | ||
|
|
||
| # Fuse HW Tiles | ||
| hw_tile_band = schedule_node_sequence_full_fuse(seq_node) | ||
|
|
||
| # Fuse Inner | ||
| seq_node = hw_tile_band.child(0) | ||
| inner_band = schedule_node_sequence_full_fuse(seq_node) | ||
|
|
||
| final_sched = inner_band.get_schedule() | ||
|
|
||
| legal_fused = schedule_is_legal_p(final_sched, total_dep) | ||
| assert legal_fused | ||
|
|
||
| c_code = P.to_c(final_sched) | ||
| assert "S_conv" in c_code | ||
| assert "S_pool" in c_code | ||
| conv = create_conv_schedule(N, Cout, H_conv, W_conv, Cin, KH_conv, KW_conv) | ||
| pool = create_pool_schedule(N, Cout, H_pool, W_pool, S_pool, KH_pool, KW_pool) | ||
|
|
||
| psched = conv.sequence(pool) | ||
|
|
||
| root = psched.get_root() | ||
| seq_node = root.child(0) | ||
|
|
||
| # 1. Split Bands | ||
| conv_band = seq_node.child(0).child(0) | ||
| conv_band = conv_band.band_split(2) # [n, k] | ||
| conv_band_2 = conv_band.child(0) | ||
| conv_band_2 = conv_band_2.band_split(2) # [h, w] | ||
|
|
||
| seq_node = conv_band_2.parent().parent().parent() | ||
|
|
||
| pool_band = seq_node.child(1).child(0) | ||
| pool_band = pool_band.band_split(2) # [n, k] | ||
| pool_band_2 = pool_band.child(0) | ||
| pool_band_2 = pool_band_2.band_split(2) # [h, w] | ||
|
|
||
| seq_node = pool_band_2.parent().parent().parent() | ||
|
|
||
| # 2. Fuse NK | ||
| nk_band = schedule_node_sequence_full_fuse(seq_node) | ||
| psched.update(nk_band) | ||
|
|
||
| # 3. Tile Conv HW | ||
| seq_node = nk_band.child(0) | ||
| conv_hw = seq_node.child(0).child(0) | ||
| conv_hw = conv_hw.band_set_permutable(1) | ||
| space = conv_hw.band_get_space() | ||
| mv = I.MultiVal.zero(space) | ||
| mv = mv.set_val(0, I.Val.int_from_si(Tile_H)) | ||
| mv = mv.set_val(1, I.Val.int_from_si(Tile_W)) | ||
|
|
||
| conv_tiled = conv_hw.band_tile(mv) | ||
| conv_tiled = conv_tiled.band_scale_down(mv) | ||
|
|
||
| # Replace Inner Band with Relative | ||
| inner = conv_tiled.child(0) | ||
| replaced = inner.delete() | ||
| new_mupa = I.MultiUnionPwAff.from_union_map(I.UnionMap(f"{{ S_conv[n, k, h, w, c, kh, kw] -> [(h%{Tile_H}), (w%{Tile_W})] }}")) | ||
| conv_tiled_inner = replaced.insert_partial_schedule(new_mupa) # noqa: F841 | ||
|
|
||
| seq_node = conv_tiled_inner.parent().parent().parent() | ||
|
|
||
| # 4. Fuse HW Tiles | ||
| hw_tile_band = schedule_node_sequence_full_fuse(seq_node) | ||
| psched.update(hw_tile_band) | ||
|
|
||
| # 5. Fuse Inner | ||
| seq_node = hw_tile_band.child(0) | ||
| inner_band = schedule_node_sequence_full_fuse(seq_node) | ||
| psched.update(inner_band) | ||
|
|
||
| assert psched.is_legal() | ||
| c_code = psched.to_c() | ||
| assert "S_conv" in c_code | ||
| assert "S_pool" in c_code |
There was a problem hiding this comment.
This pull request introduces significant new functionality in caten/polyhedral/poly_schedule.py and caten/polyhedral/stmt.py, but it lacks dedicated unit tests for these new modules. The refactored integration test is good for verifying the end-to-end workflow, but it doesn't cover edge cases, error handling, or individual methods of the new classes and functions. The lack of unit tests is likely the reason the coverage threshold was lowered in pyproject.toml. Please add comprehensive unit tests for the new modules to ensure their correctness and robustness.
| def __init__(self) -> None: | ||
| self.current_node: Optional["I.ScheduleNode"] = None | ||
| self.schedule: Optional["I.Schedule"] = None | ||
| self.current_domain: Any = None |
There was a problem hiding this comment.
The type hint for current_domain is Any, which is not very specific. To improve type safety and code clarity, consider using a more specific type. A forward reference Optional["domain"] would be more appropriate here, assuming it refers to the domain class from caten.polyhedral.schedule_tree.domain. You will need to add the import under a TYPE_CHECKING block to avoid circular dependencies.
| self.current_domain: Any = None | |
| self.current_domain: Optional["domain"] = None |
| new_reads = None | ||
| if self.reads and other.reads: | ||
| new_reads = self.reads.union(other.reads) | ||
| elif self.reads: | ||
| new_reads = self.reads | ||
| elif other.reads: | ||
| new_reads = other.reads | ||
|
|
||
| new_writes = None | ||
| if self.writes and other.writes: | ||
| new_writes = self.writes.union(other.writes) | ||
| elif self.writes: | ||
| new_writes = self.writes | ||
| elif other.writes: | ||
| new_writes = other.writes |
There was a problem hiding this comment.
The logic for combining reads and writes maps is verbose and can be simplified. Additionally, the newly added guideline in AGENTS.md recommends using operator overloads like | instead of the .union() method for better readability. This can be refactored to be more concise and adhere to the new guideline.
new_reads = self.reads | other.reads if self.reads and other.reads else self.reads or other.reads
new_writes = self.writes | other.writes if self.writes and other.writes else self.writes or other.writes| builder.current_domain = self._prev_domain | ||
|
|
||
| def finalize(self, read: Optional[Union[str, "I.UnionMap"]] = None, write: Optional[Union[str, "I.UnionMap"]] = None) -> Any: | ||
| from ..poly_schedule import PolyhedralSchedule |
There was a problem hiding this comment.
The import of PolyhedralSchedule is performed inside the finalize method. While this can be a strategy to avoid circular imports, it doesn't seem necessary here as poly_schedule.py does not import domain.py. Moving this import to the top of the file would follow standard Python conventions and improve readability.
| if self.domain_set: | ||
| uset = self.domain_set | ||
| if isinstance(uset, str): | ||
| uset = I.UnionSet(uset) | ||
| elif isinstance(uset, I.Set): | ||
| uset = I.UnionSet.from_set(uset) | ||
| self.schedule = I.Schedule.from_domain(uset) |
There was a problem hiding this comment.
| s_str = str(s) | ||
| if ":" in s_str: | ||
| tuple_part = s_str.split(":")[0].strip() | ||
| if tuple_part.startswith("{"): | ||
| tuple_part = tuple_part[1:].strip() | ||
| else: | ||
| tuple_part = s_str.strip() | ||
| if tuple_part.startswith("{") and tuple_part.endswith("}"): | ||
| tuple_part = tuple_part[1:-1].strip() |
There was a problem hiding this comment.
Extracting the tuple part by parsing the string representation of an isl.Set is fragile. This implementation relies on the specific format of isl.Set.__str__, which could change in future versions of the underlying library, breaking your code. If the caten.isl API provides a more robust way to achieve this (e.g., by inspecting the set's space and dimensions programmatically), it would be better to use that. If not, it would be good to add a comment here highlighting the fragility of this approach.
| if new_reads: | ||
| if dom.reads_map: | ||
| dom.reads_map = dom.reads_map.union(new_reads) | ||
| else: | ||
| dom.reads_map = new_reads | ||
|
|
||
| if new_writes: | ||
| if dom.writes_map: | ||
| dom.writes_map = dom.writes_map.union(new_writes) | ||
| else: | ||
| dom.writes_map = new_writes No newline at end of file |
There was a problem hiding this comment.
The logic for updating dom.reads_map and dom.writes_map is duplicated. This can be refactored into a small helper function to reduce repetition. Also, according to the new guideline in AGENTS.md, you should prefer the | operator over the .union() method.
| if new_reads: | |
| if dom.reads_map: | |
| dom.reads_map = dom.reads_map.union(new_reads) | |
| else: | |
| dom.reads_map = new_reads | |
| if new_writes: | |
| if dom.writes_map: | |
| dom.writes_map = dom.writes_map.union(new_writes) | |
| else: | |
| dom.writes_map = new_writes | |
| if new_reads: | |
| dom.reads_map = dom.reads_map | new_reads if dom.reads_map else new_reads | |
| if new_writes: | |
| dom.writes_map = dom.writes_map | new_writes if dom.writes_map else new_writes |
No description provided.