Use block load attribute to remove duplicate logic from MaterializeBlockPointer pass #2420
Conversation
After looking at the test failure I have realized this approach won't work. I did not realize we needed to mark all TensorPtr-related ops for removal when walking the module. Our pass follows the upstream pass pattern, but upstream removes all TensorPtr-related ops regardless of how they are used; we want to propagate that use information. I wrote a proof-of-concept version of this PR that essentially does a DFS on each TensorPtr op to find a load/store op, to decide whether we need to remove or keep the TensorPtr. That requires two additional data structures, though: one to keep the working set of operations and one to keep the set of values to remove. Not a big deal, but messier than I would like. This seems to work but still has some bugs. The other approach would be to write an analysis pass and tag each op with TensorPtr type for removal based on the analysis; theoretically we could keep more complex state, but the result might end up being a similar process to what I have implemented. Is one approach preferred over the other, especially since we may remove this pass in the future?
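To make the shape of that proof of concept concrete, here is a minimal sketch of the per-TensorPtr DFS with the two data structures mentioned above (a working set of operations and a set of values to remove). It is not the actual pass code: the helper name `feedsLoadOrStore` is made up, and it assumes the standard MLIR `Operation`/`Value` API plus the upstream Triton `LoadOp`/`StoreOp` classes.

```cpp
// Sketch only (not the pass implementation): starting from a
// tt.make_tensor_ptr op, walk the def-use chain with an explicit working
// set. If any transitive user is a load/store, the tensor pointer must be
// kept; otherwise every value discovered along the way is a candidate for
// removal.
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"

namespace tt = mlir::triton;

static bool feedsLoadOrStore(mlir::Operation *makeTensorPtrOp,
                             llvm::DenseSet<mlir::Value> &valuesToRemove) {
  llvm::SmallVector<mlir::Operation *> workingSet{makeTensorPtrOp};

  while (!workingSet.empty()) {
    mlir::Operation *op = workingSet.pop_back_val();

    // A load/store user means the tensor pointer is really consumed.
    if (mlir::isa<tt::LoadOp, tt::StoreOp>(op))
      return true;

    // Otherwise remember the produced values and keep following users.
    // NB: this relies on the SSA def-use graph being acyclic; see the
    // visited-set discussion later in the thread.
    for (mlir::Value result : op->getResults()) {
      valuesToRemove.insert(result);
      for (mlir::Operation *user : result.getUsers())
        workingSet.push_back(user);
    }
  }
  return false; // nothing reads/writes through this pointer: safe to remove
}
```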
Force-pushed from 4683fa6 to 365d43e.
Here's the new, working algorithm. First, walk the tree and look for [...]. Next, we make a pass through all the ops again and make sure we remove all [...]. I also had to move the fp8 column-major block load logic to [...].
Force-pushed from d939995 to 546e577.
Force-pushed from 546e577 to ccb3953.
Review comment on the line `workingSet.erase(crtOpItr);`:
any chance of inserting elements erased before? do we need a visited set?
I don't think so; we would need an instance where the result from an op is used by an op lower down in the call graph, but where that op returns its result to an op higher in the call graph. Maybe if we had conditional branching, but I am not sure how that would be handled.
I might be able to refactor this to eliminate the workingSet, but for now I added a contrived lit test that has tensor ptrs but no loads or stores. This is the scenario that I think is most likely to get into an infinite loop, and it passes.
I suspect the following LLVM IR code would run into an infinite loop:
```
loop:
  %0 = phi [%1, loop]
  %1 = add %0, 1
  br cond, loop, exit
```
so in MLIR, maybe something like:
```
%0 = maketensorptr
%2:1 = scf.for %iv = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg = %0) {
  %1 = add %arg, 1
  scf.yield %1
}
```
That last commit should instead say [...].
I don't know why GitHub won't let me reply directly to your comment, but isn't that the same as the test I added, with `add` replaced by [...]?
There might be cases this misses, but I think we should add them as we find them. Considering all possible cases upfront will be quite difficult, and note that the old algorithm behaved similarly (it only considered the inner body of [...]).
…st on xpu" This reverts commit ccb3953. Does not work on LTS currently, as we are removing tensor pointers on LTS and producing different TTGIR.
Force-pushed from 30ed7d2 to 0a42f99.
This caused the lambda in `walk` to be typed, which means we need an explicit return everywhere. This reverts commit 0a42f99.
Probably need to reply to the original thread, which is higher up, but it doesn't matter.
True, I guess it is fine as MLIR doesn't have PHI nodes.
My concern is not whether we generate good code, but whether we would ever run into an infinite loop, which is behavior we never want. Anyway, I cannot think of an example.
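For what it's worth, if the termination concern ever needed to be settled mechanically, a visited set is a cheap guard. A minimal sketch, assuming the same worklist shape as above; `visitTransitiveUsers` is a hypothetical helper, not code from this PR:

```cpp
// Sketch only: a worklist traversal that cannot loop forever, because each
// operation is processed at most once. Even if the traversal were extended
// to follow values fed back through scf.for iter_args / scf.yield, the
// visited check keeps it from cycling.
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

static void
visitTransitiveUsers(mlir::Operation *root,
                     llvm::function_ref<void(mlir::Operation *)> visit) {
  llvm::SmallVector<mlir::Operation *> workingSet{root};
  llvm::SmallPtrSet<mlir::Operation *, 16> visited;

  while (!workingSet.empty()) {
    mlir::Operation *op = workingSet.pop_back_val();
    if (!visited.insert(op).second)
      continue; // already seen: re-inserted elements are skipped

    visit(op);
    for (mlir::Operation *user : op->getUsers())
      workingSet.push_back(user);
  }
}
```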
We cannot lower a transposed A matrix to a transposed 2D block load. Instead, the load is lowered via the LLVM path introduced in #2181. There appears to be a performance regression in this path, which is slower than materializing the block in SLM and then reading into registers and computing the dot product from there. Using the work in #2420 I am able to drop the block load attribute for this case and go down the non-block-ptr path.

Performance on main:
```
Compute A x B
✅ Triton and Torch match
Time for torch: 0.32444801926612854 ms
Time for triton: 0.44371041655540466 ms
Compute A x B.T
✅ Triton and Torch match
Time for torch: 0.32708799839019775 ms
Time for triton: 0.634996771812439 ms
Compute A.T x B
✅ Triton and Torch match
Time for torch: 0.31204161047935486 ms
Time for triton: 3.4140689373016357 ms
Compute A.T x B.T
✅ Triton and Torch match
Time for torch: 0.45701122283935547 ms
Time for triton: 3.7463345527648926 ms
```

Performance on this PR:
```
Compute A x B
✅ Triton and Torch match
Time for torch: 0.3081200122833252 ms
Time for triton: 0.44333598017692566 ms
Compute A x B.T
✅ Triton and Torch match
Time for torch: 0.33799198269844055 ms
Time for triton: 0.6391856074333191 ms
Compute A.T x B
✅ Triton and Torch match
Time for torch: 0.31700319051742554 ms
Time for triton: 1.5733630657196045 ms
Compute A.T x B.T
✅ Triton and Torch match
Time for torch: 0.45083683729171753 ms
Time for triton: 1.8271965980529785 ms
```

Note that the important commit is `31386ef1132c3f6cf9cb5f1063ecfab705f4c2a1`. Once #2420 is merged I will rebase this. Depends on #2420. Links to #1795.
The logic in `shouldRemove` in the `RewriteTensorPointer` pass duplicates the same logic in `MaterializeBlockPointer`: https://github.com/intel/intel-xpu-backend-for-triton/blob/main/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp#L50

This duplication is necessary for the matrix transpose multiplication case because the block pointer is defined outside an `scf.for` loop, but the load is inside the loop. The previous logic in `RewriteTensorPointer` could not "see" into the `scf.for` loop block and decided to remove the tensor pointer even though its result was used by a block load. This PR changes the processing of an `scf.for` loop to associate the arguments to the for loop with any users inside the for loop body block. That change, coupled with #2400, makes the duplicated logic unnecessary for the matrix transpose case.

But the goal of #2380 is to remove the duplicated code. I believe we should be able to do this by checking the block attribute on the load, but I am wondering whether there are other tensor pointers that we want to keep even if the load/store is not tagged with the block attribute. For now I am only checking the block tag inside the scf section of `RewriteTensorPointer`, but I think we should consider expanding this check, and possibly removing the DPAS layout checks inside `RewriteTensorPointer`. In other words, if we have a tensor pointer followed by a block load, we always keep the tensor pointer. @etiotto @whitneywhtsang @chengjunlu do other cases exist where we do not have a blocked load but want to keep the tensor pointer? One case I can think of is where we have a DPAS dot operation, a blocked load input, and an input that cannot be a block load (but could be materialized through scatter/gather vectorization per #2181). So perhaps the logic should be: [...]
I am going to run this through CI and benchmarks and make sure there are no regressions while we discuss the above topic.
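As a rough illustration of the block-attribute check described above (a sketch only, not the PR's implementation): the attribute name `"triton_intel_gpu.block_io"` is an assumption here and would really be queried through the dialect, and in the transpose case this direct-user check would still need to be combined with the walk through `scf.for` arguments that this PR adds.

```cpp
// Sketch only: keep a tensor pointer when one of its load/store users was
// tagged by MaterializeBlockPointer as a 2D block IO operation. The
// attribute name below is an assumption, not necessarily the exact string
// used by the pass.
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/StringRef.h"

namespace tt = mlir::triton;

static bool hasBlockIOUser(mlir::Value tensorPtr) {
  constexpr llvm::StringLiteral kBlockIOAttr("triton_intel_gpu.block_io");

  for (mlir::Operation *user : tensorPtr.getUsers()) {
    if (mlir::isa<tt::LoadOp, tt::StoreOp>(user) &&
        user->hasAttr(kBlockIOAttr))
      return true; // lowered as a block load/store: keep the tensor pointer
  }
  return false; // no tagged user: candidate for rewriting away
}
```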
Close #2380