Conversation
…STVisitor Renderer, Sequence Schedule. Fix e2e_matmul.
🤖 Hi @hikettei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
Summary of Changes: Hello @hikettei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request fundamentally refactors the compiler's core, introducing a robust framework for defining and optimizing tensor computations. It establishes a new Intermediate Representation (IR) and integrates a polyhedral model for advanced loop transformations, enabling efficient code generation for various targets. The changes provide a powerful and flexible foundation for future compiler development. Highlights
Using Gemini Code Assist: The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips. Invoking Gemini: You can request assistance from Gemini at any point by creating a comment using either
Customization: To customize Gemini Code Assist for your GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback: Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
🤖 I'm sorry @hikettei, but I was unable to process your request. Please see the logs for more details.
Code Review
This pull request introduces a significant and well-designed infrastructure for a new frontend DSL and IR, including tracing, a pattern matcher, and integration with a polyhedral backend. This is a solid foundation for the project. My review focuses on a critical thread-safety issue in the tracing mechanism, a few functional gaps between the frontend API and the backend implementation, and several opportunities for improving maintainability and performance.
```python
def unroll(factor: int = 4) -> Directive:
    return Directive("unroll", (factor,))

# --- Range ---
_range_counter = 0
```
The global _range_counter is not thread-safe. If two kernels are traced in parallel threads, they will share and modify this counter, leading to a race condition. This should be moved into the GraphBuilder class in caten/trace.py to make it thread-local. This is the first of several changes to address this.
```python
global _range_counter
self.args = args
self.iter_sym = Symbol(f"i{_range_counter}")
self.directives: List[Directive] = []
_range_counter += 1
```
To fix the thread-safety issue with _range_counter, this logic should be updated to use a counter from the GraphBuilder instance instead of a global variable. This change depends on adding range_counter to the GraphBuilder class.
Suggested change:

```python
builder = get_builder()
self.args = args
self.iter_sym = Symbol(f"i{builder.range_counter}")
self.directives: List[Directive] = []
builder.range_counter += 1
```
```python
global _range_counter
_range_counter = 0
```
```python
class GraphBuilder:
    def __init__(self) -> None:
        # The root block of the graph
        self.root_block: List[Node] = []
        # Stack of active blocks (lists of nodes) where new nodes are appended
        self.block_stack: List[List[Node]] = [self.root_block]
        # Inputs to the kernel
        self.inputs: List[Node] = []
```
To fix a thread-safety issue with loop variable naming, a counter for ranges should be stored here in GraphBuilder instead of as a global variable. This makes the counter thread-local via contextvars.
Suggested change:

```python
class GraphBuilder:
    def __init__(self) -> None:
        # The root block of the graph
        self.root_block: List[Node] = []
        # Stack of active blocks (lists of nodes) where new nodes are appended
        self.block_stack: List[List[Node]] = [self.root_block]
        # Inputs to the kernel
        self.inputs: List[Node] = []
        self.range_counter: int = 0
```
```python
def reset(self) -> None:
    self.root_block = []
    self.block_stack = [self.root_block]
    self.inputs = []
```
The range_counter should be reset here to ensure thread-safety and correctness for subsequent kernel traces.
Suggested change:

```python
def reset(self) -> None:
    self.root_block = []
    self.block_stack = [self.root_block]
    self.inputs = []
    self.range_counter = 0
```
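The contextvars-based approach described in this thread could look roughly like the sketch below. This is a minimal stand-in, not the actual `caten/trace.py` code: `GraphBuilder`, `get_builder`, and `fresh_iter_name` are simplified illustrations of how a per-context builder makes the range counter safe to use from multiple threads.

```python
import contextvars

class GraphBuilder:
    def __init__(self) -> None:
        self.range_counter: int = 0

    def reset(self) -> None:
        self.range_counter = 0

# Each thread (and each asyncio task) sees its own builder instance,
# because ContextVar values are not shared across threads.
_builder_var: contextvars.ContextVar["GraphBuilder"] = contextvars.ContextVar("builder")

def get_builder() -> GraphBuilder:
    try:
        return _builder_var.get()
    except LookupError:
        builder = GraphBuilder()
        _builder_var.set(builder)
        return builder

def fresh_iter_name() -> str:
    # Replaces the global `_range_counter`: each context counts independently.
    builder = get_builder()
    name = f"i{builder.range_counter}"
    builder.range_counter += 1
    return name
```

With this scheme, two kernels traced in parallel threads each get their own `i0`, `i1`, … sequence instead of racing on a shared global.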
```python
except Exception as e:
    print(f"WARNING: Failed to insert partial schedule: {e} for {mupa_str}")
```
Catching a broad Exception and only printing a warning can hide critical errors during schedule construction. If creating the MultiUnionPwAff fails, it likely indicates a fundamental problem with the SCoP extraction or the mupa_str formatting. This error should be propagated to aid debugging.
Suggested change:

```python
except Exception as e:
    raise RuntimeError(f"Failed to create partial schedule from: {mupa_str}") from e
```
```python
# Note: AstToGraphConverter does NOT preserve original directives attached to range(),
# because they are lost when converting to ISL AST (unless marked).
# To support directives with Polyhedral model, we need to add marks in Schedule tree.

# For now, if we want directives to appear, we rely on the fact that user manually adds them via schedule API,
# OR we implement a mechanism to carry them over.
# The user request "with (C.range(10) | C.parallel())" implies they want it in the final code.
# Since we reconstruct graph from ISL AST, these directives are currently LOST.
```
As noted in the comments, directives like parallel() are currently lost during the polyhedral compilation pipeline (graph -> SCoP -> ISL AST -> graph). This means that although the frontend syntax is present, it has no effect on the final generated code. This is a significant functional gap that could be misleading to users.
To fix this, the directives need to be propagated through the polyhedral model. A common way to do this is to associate them with loops during SCoP construction and then use isl_schedule_node_insert_mark to insert mark nodes into the ISL schedule tree. The AstToGraphConverter would then need to recognize these marks and re-create the directive nodes in the final graph.
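The mark-based propagation suggested above can be modeled in plain Python. The sketch below is a deliberately simplified illustration (real code would call `isl_schedule_node_insert_mark` and read mark nodes back out of the ISL AST; here dicts stand in for the schedule tree, and `build_scop` / `reconstruct_graph` are hypothetical names):

```python
# Simplified model of carrying directives through a schedule round-trip.
# Marks recorded during SCoP construction survive the transformation and
# are re-attached when converting the AST back into graph nodes.

def build_scop(loops):
    """Record a mark for every loop that carries directives.

    `loops` maps an iterator symbol (e.g. "i0") to a list of
    directive names such as ["parallel"].
    """
    return {sym: tuple(ds) for sym, ds in loops.items() if ds}

def reconstruct_graph(ast_loops, marks):
    """Re-attach directives when converting the (possibly reordered)
    AST loops back into graph RANGE nodes."""
    graph = []
    for sym in ast_loops:
        graph.append({"op": "RANGE", "iter": sym,
                      "directives": list(marks.get(sym, ()))})
    return graph
```

The key point is that the marks live outside the part of the representation that the scheduler rewrites, so directives like `parallel()` are no longer lost between `C.range(10) | C.parallel()` and the generated code.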
```python
__all__ = [
    "vars",
    "range",
    "kernel",
    "Tensor",
    "TensorSpec",
    "float32",
    "int32",
    "f32",
    "i32",
    "DType",
    "when",
    "parallel",
    "vectorize",
    "unroll",
]
```
The __all__ list is not alphabetically sorted. Sorting it improves readability and maintainability, making it easier to see what is exported from the module.
Suggested change:

```python
__all__ = [
    "DType",
    "Tensor",
    "TensorSpec",
    "f32",
    "float32",
    "i32",
    "int32",
    "kernel",
    "parallel",
    "range",
    "unroll",
    "vars",
    "vectorize",
    "when",
]
```
```python
self.root_block: List[Node] = []
# Stack of active blocks (lists of nodes) where new nodes are appended
self.block_stack: List[List[Node]] = [self.root_block]
# Inputs to the kernel
```
The inputs attribute is a list. In caten/kernel.py, there's a check if arg.node not in builder.inputs which is an O(n) operation. For kernels with many input tensors, this could be a performance bottleneck during tracing. Using a dictionary would provide O(1) lookups while preserving insertion order (in Python 3.7+).
This would require updating register_input to use self.inputs[node] = True, reset to use self.inputs = {}, and the compile method in ClangRuntime to use list(builder.inputs).
Suggested change:

```python
# Inputs to the kernel
self.inputs: Dict[Node, bool] = {}
```
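To make the dict-as-ordered-set idea concrete, here is a minimal sketch (the class and method names mirror the review suggestion but are simplified stand-ins for the real `caten/trace.py` code):

```python
class GraphBuilder:
    def __init__(self) -> None:
        # Dict used as an ordered set: O(1) membership checks,
        # insertion order preserved (guaranteed since Python 3.7).
        self.inputs: dict = {}

    def register_input(self, node) -> None:
        if node not in self.inputs:  # O(1), vs. O(n) for a list
            self.inputs[node] = True

    def input_list(self):
        # Equivalent of the old list view, e.g. for ClangRuntime.compile().
        return list(self.inputs)
```

Callers that previously iterated the list keep working via `list(builder.inputs)`, while the membership test in `caten/kernel.py` becomes constant-time.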
```python
...

# Execute in the complement domain of the previous 'when'
with C.otherwise():
```
The design document mentions with C.otherwise(): and C.alloc() (line 176), but these features do not appear to be implemented in this pull request. This could be confusing for users reading the documentation. It would be clearer to either remove these for now or mark them explicitly as "planned" or "future work".
Summary of Changes: Hello @hikettei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request represents a substantial architectural overhaul of the Caten compiler. It establishes a robust frontend Domain-Specific Language (DSL) for defining tensor computations, backed by a comprehensive Intermediate Representation (IR) built from Highlights
Code Review
This pull request introduces a significant and well-structured rewrite of the core IR and compilation pipeline. The introduction of a tracing-based frontend, a multi-level IR, and integration with polyhedral scheduling via ISL is a massive step forward. The code is generally clean and demonstrates a strong architectural vision. My review focuses on several key areas: strengthening the tracing mechanism to handle outputs and keyword arguments, improving robustness in the pattern matcher and exception handling, reducing code duplication, and aligning the documentation with the implementation. There are a few critical and high-severity issues related to the tracing design and documentation that should be addressed, but overall this is an impressive and promising contribution.
```python
_ = func(*func_args)

# 4. Finalize Graph
full_graph = builder.root_block
```
The current tracing mechanism does not identify the output nodes of the kernel. The return value of the traced function is discarded (_ = func(*func_args)), and the entire root_block is passed to the compiler. This prevents dead code elimination (via the unimplemented resolve_graph) and means many unnecessary intermediate nodes are processed.
A mechanism to designate outputs is needed. For example, the kernel function's return statement could be used to identify the output tensor(s). The wrapper would then capture this return value and use it to determine the actual output nodes for the graph.
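One way the wrapper could capture outputs from the return value is sketched below. This is an illustration only: `Node`, `Tensor`, and `trace` here are minimal stand-ins, not the actual caten classes, and the real tracer would feed `output_nodes` into dead-code elimination.

```python
# Minimal stand-ins for the real IR classes.
class Node:
    def __init__(self, op):
        self.op = op

class Tensor:
    def __init__(self, node):
        self.node = node

def trace(func, *func_args):
    # Keep the traced function's return value instead of discarding it.
    result = func(*func_args)
    outputs = result if isinstance(result, tuple) else (result,)
    # Only nodes reachable from these outputs need to survive DCE;
    # everything else in root_block is a candidate for elimination.
    return [t.node for t in outputs if isinstance(t, Tensor)]
```

Compared to `_ = func(*func_args)`, this gives the compiler an explicit root set for liveness, so `resolve_graph` can prune unused intermediate nodes.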
```python
def _to_node(obj: Any) -> Node:
    if isinstance(obj, Node):
        return obj
    return Node(MetaOps.CONST, (), arg=obj)
```
```python
global _range_counter
self.args = args
self.iter_sym = Symbol(f"i{_range_counter}")
self.directives: List[Directive] = []
_range_counter += 1
```
Using a global counter _range_counter for generating unique symbol names is not thread-safe. If kernel compilation were to be parallelized in the future, this could lead to race conditions and non-unique symbol names. Consider using a thread-local counter or passing a context object through the compilation pipeline to manage state like this.
```python
if args:
    for arg in args:
        if isinstance(arg, Tensor):
            func_args.append(arg)
            if arg.node.op == MetaOps.PLACEHOLDER:
                if arg.node not in builder.inputs:
                    builder.register_input(arg.node)
        else:
            func_args.append(arg)
```
The logic for creating placeholder nodes for function arguments only handles positional arguments (args). If tensors are passed as keyword arguments, they will be missed, leading to incorrect tracing. The argument handling should be updated to correctly process both positional and keyword arguments, perhaps by using inspect.signature.bind to map all provided arguments to their corresponding parameters.
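The `inspect.signature.bind` approach can be sketched as follows. `collect_tensor_args` is a hypothetical helper name; the point is that binding normalizes positional and keyword arguments into one ordered mapping, so tensors passed via kwargs are no longer skipped:

```python
import inspect

def collect_tensor_args(func, *args, **kwargs):
    """Bind positional AND keyword arguments to the function's parameters,
    so tracing sees every argument regardless of how it was passed."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    # bound.arguments is ordered by parameter position.
    return list(bound.arguments.items())
```

The tracer would then iterate this list (instead of raw `args`) when deciding which arguments are `Tensor` placeholders to register as inputs.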
```markdown
* **Methods**:
  * `optimize()`: Apply transformations (tiling, fusion).
  * `lower()`: Convert to target-specific code string.

### `caten.polyhedral.ScheduleTree` (Wrappers)
* **`Domain`**: Root of the tree, defines the iteration space.
* **`Band`**: Represents a loop nest. Handles tiling, permuting.
* **`Filter`**: Selects a subset of the domain (or specific statements) for the subtree.
* **`Sequence`**: Specifies sequential execution order of children.
* **`Mark`**: Attaches metadata (e.g., "SIMD", "Unroll").

### `caten.render.Renderer`
* **Responsibility**: Visiting the AST/IR and emitting string code.
* **Subclasses**: `CRenderer`, `CUDARenderer`, `MetalRenderer`.
* **Input**: `isl_ast_node` (from Polyhedral) + `ops` graph.
* **Output**: Source code string.

### `caten.runtime.Runtime`
* **Responsibility**: Hardware abstraction layer.
* `lower(target="c")`: Returns target-specific source code string.
* `compile()`: JIT compiles and returns a callable Python executable.
```
The documentation for caten.kernel.Kernel lists lower() and compile() as methods. However, the implementation in caten/kernel.py does not have these methods on the Kernel class. The compilation logic is handled within the @C.kernel decorator. The documentation should be updated to accurately reflect the class's API and the role of the decorator.
```python
if isinstance(arg, Tensor):
    func_args.append(arg)
    if arg.node.op == MetaOps.PLACEHOLDER:
        if arg.node not in builder.inputs:
```
```python
except Exception as e:
    print(f"WARNING: Failed to insert partial schedule: {e} for {mupa_str}")
```
Catching a generic Exception can hide bugs and make debugging difficult. It's better to catch more specific exceptions that you expect isl to raise (e.g., I.Error). Also, printing a warning to stdout is not ideal for a library. Consider using Python's logging module, which allows users of the library to control the verbosity and destination of such messages.
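A sketch of the suggested pattern, catching a narrower exception and routing diagnostics through `logging`. The names `try_insert_partial_schedule` and the `ValueError` stand-in are illustrative; the real code would catch whatever specific error type the isl binding raises:

```python
import logging

# Library code should emit diagnostics through logging, not print();
# callers then control verbosity and destination via logger config.
logger = logging.getLogger("caten.polyhedral")

def try_insert_partial_schedule(insert_fn, mupa_str):
    try:
        return insert_fn(mupa_str)
    except ValueError as e:  # catch the specific isl error, not bare Exception
        logger.warning("Failed to insert partial schedule: %s for %s", e, mupa_str)
        return None
```

Unexpected exception types now propagate instead of being silently swallowed, while the expected failure mode is still reported in a way the embedding application can filter or redirect.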
```python
# Note: AstToGraphConverter does NOT preserve original directives attached to range(),
# because they are lost when converting to ISL AST (unless marked).
# To support directives with Polyhedral model, we need to add marks in Schedule tree.

# For now, if we want directives to appear, we rely on the fact that user manually adds them via schedule API,
# OR we implement a mechanism to carry them over.
# The user request "with (C.range(10) | C.parallel())" implies they want it in the final code.
# Since we reconstruct graph from ISL AST, these directives are currently LOST.

# To fix this:
# We need to associate directives with the statement or loop in SCoP construction,
# and then re-apply them during scheduling or rendering.
# ISL supports 'mark' nodes. We can insert marks for directives.

# HOWEVER, for this turn, I'll just implement the syntax support and rendering capability.
# Connecting them through ISL requires deeper changes (inserting marks in schedule).

# Wait, if I use "Polyhedral Generated Kernel", I'm going through ISL.
# If I want to demonstrate directives, maybe I should skip ISL for a simple example?
# No, the requirement is strict about Polyhedral.

# I will leave the ISL integration part of directives as a limitation/TODO for now,
# as correct propagation requires AST generation callbacks or schedule tree manipulation.

# But to satisfy "PatternMatcher is not implemented", I prioritized that.

# Back to rendering:
# AstToGraphConverter uses the ISL AST.
# If we want directives, we need to modify PolyhedralSchedule to insert marks based on SCoP info.
# ScopStatementInfo needs to store directives? No, Range directives belong to loops, not statements directly.

# This is complex. I will implement the syntax and the renderer support.
# Propagation through ISL is out of scope for "PatternMatcher implementation" task?
# The user asked for "2. with (C.range(10) | C.parallel()) ... examples".
```
This block contains extensive comments explaining design limitations and future work (e.g., loss of directives through the ISL round-trip). While insightful, this kind of documentation is better suited for the design document (docs/IR_AND_DESIGN.md) or as TODO comments in the code. Keeping long, explanatory prose inside a method clutters the implementation.
```python
class ASTNodeType:
    ERROR = -1
    FOR = 1
    IF = 2
    BLOCK = 3
    MARK = 4
    USER = 5
```
The ASTNodeType class is defined with class attributes to represent constants. For better type safety and clarity, this should be defined as a proper enumeration using enum.IntEnum.
Suggested change:

```python
from enum import IntEnum

class ASTNodeType(IntEnum):
    ERROR = -1
    FOR = 1
    IF = 2
    BLOCK = 3
    MARK = 4
    USER = 5
```
```python
__all__ = [
    "vars",
    "range",
    "kernel",
    "Tensor",
    "TensorSpec",
    "float32",
    "int32",
    "f32",
    "i32",
    "DType",
    "when",
    "parallel",
    "vectorize",
    "unroll",
]
```
For better maintainability and easier navigation, it's a good practice to keep the __all__ list sorted alphabetically.
Suggested change:

```python
__all__ = [
    "DType",
    "Tensor",
    "TensorSpec",
    "f32",
    "float32",
    "i32",
    "int32",
    "kernel",
    "parallel",
    "range",
    "unroll",
    "vars",
    "vectorize",
    "when",
]
```
|
Just `import caten as C` should expose the entire API; tests and sample code should follow that convention as well.
|
The graph must always be a DAG; do not build lists.
|
get_kernel should be a wrapped function, not a decorator...