[optimize-dot-operands]: Fuse load and trans operations - part 3 #4537
Signed-off-by: Tiotto, Ettore <[email protected]>
Depends on: #4468
Pull Request Overview
This PR enhances dot operand optimization by fusing load and transpose operations in separate loops when the def-use chains originate from a make_tensor_ptr operation, and by refactoring cleanup routines.
- Added a new optimization pass (optimize_dot_operands) in multiple backend components.
- Introduced a new eraseOperations utility and refactored fusion logic in OptimizeDotOperands.cpp.
- Updated test cases to validate proper fusion and non-fused behavior.
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
File | Description |
---|---|
third_party/intel/triton_xpu.cc | Added optimize_dot_operands pass registration. |
third_party/intel/lib/Utils/Utility.cpp | Added a new eraseOperations function for cleanup operations. |
third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp | Refactored fusion logic and propagation routines to support optimized chaining. |
third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp | Removed redundant finalize() in favor of using eraseOperations. |
third_party/intel/include/Utils/Utility.h | Declared the new eraseOperations function. |
third_party/intel/backend/compiler.py | Registered the new optimize_dot_operands pass in the compiler backend. |
test/TritonIntelGPU/dot-operands.mlir | Updated test cases to reflect changes in fusion behavior and new pass functionality. |
Comments suppressed due to low confidence (2)
third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp:161
- [nitpick] The singleUsersInChain function is quite complex; consider refactoring the logic or adding more inline comments to improve readability and maintainability.
// Determine whether all operations in the def-use chain from \p start to
third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp:112
- [nitpick] Consider renaming the lambda 'usedByDotOp' to a more descriptive name such as 'isChainedToDotOp' to clarify its purpose.
auto usedByDotOp = [](tt::TransOp transOp) {
```diff
@@ -68,7 +71,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = desc_k.load([0, offsetk_y])
        if dtype == tl.float8e5:
```
For fp16 we undo the source code changes we made, so the code is now back to the original. For FP8 we keep the source code changes until we can issue DPAS instructions for them (after packing 2 fp8 elements into a fp16).
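The "2 fp8 elements into a fp16" packing mentioned above can be sketched conceptually in NumPy; this is an illustration only, not the backend code. `uint8` and `uint16` stand in for the fp8 and fp16 storage types, and the packing is a pure bit reinterpretation:

```python
import numpy as np

# Conceptual sketch: DPAS consumes 16-bit lanes, so two 8-bit fp8 values can
# be carried in one fp16-sized lane by bitwise packing.
fp8_vals = np.arange(8, dtype=np.uint8)   # 8 fp8 elements (raw bytes)
packed = fp8_vals.view(np.uint16)         # 4 lanes, 2 fp8 elements per lane

# The reinterpretation changes no bytes and round-trips losslessly.
assert packed.nbytes == fp8_vals.nbytes
assert np.array_equal(packed.view(np.uint8), fp8_vals)
```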
```diff
@@ -80,15 +79,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
    %c1_i64 = arith.constant 1 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %0 = tt.get_program_id x : i32
```
Just making the test simpler here
ping @whitneywhtsang, @chengjunlu, @LiyangLingIntel any comments?
```cpp
// Prune candidate chains containing load/trans operations that cannot be
// safely fused.
prune(chains);
```
Do you think it is worth first pruning rootToChains down to chains that contain at least one candidate? That way we won't clone the chain if there will be no candidates.
Another thought is to have a flag to indicate whether we want to clone for a particular root in rootToChains: if there is no candidate, or all chains are candidates, then there is no need to clone.
Not sure I fully understand the question. We collect def-use chains terminated at a TransOp only if that operation is a candidate (therefore not all TransOps are collected into a chain to start with). Then, if there is only one chain, no cloning is necessary. If there are 2 or more chains, we clone the root operation only if it is the start operation of more than one chain. After that, we prune chains if we detect that operations in the "middle" of the chain have more than one user. The remaining chains are the final candidates for fusion.
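The collect/clone/prune flow described above can be sketched with a small Python model. Everything here is illustrative (the `Op` class, `prune`, and `root_needs_clone` are hypothetical stand-ins, not the pass's actual C++ API): a chain survives only if every operation strictly between the root and the terminating transpose has exactly one user, and the shared root would be cloned only when it starts more than one surviving chain.

```python
class Op:
    """One operation in a def-use chain (illustrative, not the real IR type)."""
    def __init__(self, name):
        self.name = name
        self.users = []

    def add_user(self, op):
        self.users.append(op)
        return op

def prune(chains):
    # Keep a chain only if every "middle" op (between root and the
    # terminating transpose) has exactly one user.
    return [c for c in chains if all(len(op.users) == 1 for op in c[1:-1])]

def root_needs_clone(chains):
    # Clone the root only when it starts more than one surviving chain.
    return sum(1 for c in chains if c[0] is chains[0][0]) > 1

# Two chains sharing one make_tensor_ptr root; load_b also feeds a side use,
# so its chain must be pruned.
root = Op("make_tensor_ptr")
load_a = root.add_user(Op("load_a"))
trans_a = load_a.add_user(Op("trans_a"))
load_b = root.add_user(Op("load_b"))
trans_b = load_b.add_user(Op("trans_b"))
load_b.add_user(Op("side_use"))

chains = prune([[root, load_a, trans_a], [root, load_b, trans_b]])
```

With only one chain surviving, no cloning is needed; had both survived, the shared root would be cloned so each chain owns its start operation.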
After discussing offline I understand the suggestion. We agreed to improve the implementation in the next PR.
Thanks @whitneywhtsang for the prompt review!
Addresses first round of comments. I still have some comments to work on @whitneywhtsang.
Should be fine in this PR as it is limited to one user per operation in the chain, but in general we need to be careful that there can be more than one chain with the same start and the same end.
Enhance the transformation to allow multiple load+transpose fusion opportunities in separate for loops when the def-use chains corresponding to the opportunities originate at the same make_tensor_ptr operation.
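The shape the enhanced pass targets can be illustrated with plain NumPy (a hypothetical stand-in for the Triton IR, not the actual kernel code): two independent loops each load a block from the same shared base tensor, transpose it, and feed a dot. The fusion folds each transpose into its load instead of materializing the transposed block.

```python
import numpy as np

# Shared root tensor, playing the role of a single make_tensor_ptr result.
base = np.arange(64, dtype=np.float32).reshape(8, 8)
q = np.ones((4, 4), dtype=np.float32)

acc = []
for row in (0, 4):                    # two separate loops/fusion opportunities
    b_block = base[row:row + 4, 0:4]  # load from the shared root
    acc.append(q @ b_block.T)         # transpose feeding a dot
    # fused form: read base[row:row+4, 0:4] column-major directly,
    # eliminating the standalone transpose in each loop
```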