
Conversation

@alexbaden
Contributor

The logic in shouldRemove in the RewriteTensorPointer pass duplicates the logic in MaterializeBlockPointer: https://github.com/intel/intel-xpu-backend-for-triton/blob/main/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp#L50

This duplication is necessary for the matrix transpose multiplication case because the block pointer is defined outside an scf.for loop, but the load is inside the loop. The previous logic in RewriteTensorPointer could not "see" into the scf.for loop body and decided to remove the tensor pointer even though its result was used by a block load. This PR changes the processing of an scf.for loop to associate the arguments to the for loop with any users inside the loop body block. That change, coupled with #2400, makes the duplicated logic unnecessary for the matrix transpose case.
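To make that concrete, here is a rough sketch (illustrative only, not the exact code in this PR) of associating an scf.for init argument with its region iter arg so that users inside the loop body are visited:

```cpp
// Illustrative only: map a tensor pointer passed into an scf.for to the users
// of the corresponding region iter arg, so a use-def walk can "see" loads
// inside the loop body.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static void collectLoopBodyUsers(scf::ForOp forOp, Value tensorPtr,
                                 SmallVectorImpl<Operation *> &users) {
  // Each init operand of the scf.for feeds exactly one block argument
  // (region iter arg) of the loop body.
  for (auto [init, iterArg] :
       llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs())) {
    if (init != tensorPtr)
      continue;
    // Visit the ops inside the body that consume the tensor pointer.
    for (Operation *user : iterArg.getUsers())
      users.push_back(user);
  }
}
```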

But the goal of #2380 is to remove the duplicated code. I believe we should be able to do this by checking the block attribute on the load, but I am wondering if there are other tensor pointers that we want to keep even if the load/store is not tagged with the block attribute. For now I am only checking the block tag inside the scf section of RewriteTensorPointer, but I think we should consider expanding this check and possibly removing the DPAS layout checks inside RewriteTensorPointer. In other words, if we have a tensor pointer followed by a block load, we always keep the tensor pointer. @etiotto @whitneywhtsang @chengjunlu do other cases exist where we do not have a block load but still want to keep the tensor pointer? One case I can think of is a DPAS dot operation with one block-load input and one input that cannot be a block load (but could be materialized through scatter/gather vectorization per #2181). So perhaps the logic should be:

check if the tensor pointer is used by a load/store
if yes and the block attribute is on the load/store -- keep the tensor pointer
if yes and there is no block attribute -- keep the tensor pointer if the tensor + load/store has a DPAS layout
if no -- remove the tensor pointer
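In C++ terms, that decision might look roughly like the following sketch (the attribute name and the DPAS-layout parameter are assumptions standing in for whatever the pass actually uses):

```cpp
// A sketch of the proposed decision, in isolation. The attribute name below
// is an assumption about what MaterializeBlockPointer tags loads/stores with,
// and hasDpasLayout is a stand-in for the existing layout check in the pass.
#include "mlir/IR/Operation.h"
#include "llvm/ADT/StringRef.h"

inline bool shouldKeepTensorPtr(mlir::Operation *loadOrStore,
                                bool hasDpasLayout) {
  // Assumed attribute name; substitute whatever MaterializeBlockPointer sets.
  constexpr llvm::StringLiteral kBlockIOAttr("triton_intel_gpu.block_io");
  // Block attribute on the load/store: always keep the tensor pointer.
  if (loadOrStore->hasAttr(kBlockIOAttr))
    return true;
  // No block attribute: keep only if the tensor + load/store has a DPAS layout.
  return hasDpasLayout;
}
// A tensor pointer with no load/store user at all would simply be removed
// before this predicate is ever consulted.
```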

I am going to run this through CI and benchmarks and make sure there are no regressions while we discuss the above topic.
 
Closes #2380.

@alexbaden
Contributor Author

After looking at the test failure, I have realized this approach won't work. I did not realize we needed to mark all TensorPtr-related ops for removal when walking the module. Our pass follows the upstream pass pattern, but upstream removes all TensorPtr-related ops regardless of how they are used - we want to propagate that use info.

I wrote a proof-of-concept version of this PR which essentially does a DFS from each TensorPtr op to find a load/store op and decide whether we need to remove or keep the TensorPtr. That requires two additional data structures, though - one to keep the working set of operations, and one to keep the set of values to remove. Not a big deal, but messier than I would like. This seems to work but still has some bugs. The other approach would be to write an analysis pass and tag each op with TensorPtr type for removal based on the analysis - theoretically we can keep more complex state, but the result would likely be a process similar to what I have implemented. Is one approach preferred over the other, especially since we may remove this pass in the future?

@alexbaden alexbaden force-pushed the alex/use_block_load_attr_in_rewrite_tensor_pointer branch 2 times, most recently from 4683fa6 to 365d43e on October 8, 2024 01:41
@alexbaden alexbaden marked this pull request as ready for review October 8, 2024 01:43
@alexbaden
Contributor Author

Here's the new, working algorithm:

First, walk the tree and look for MakeTensorPtr ops. For each MakeTensorPtr op, we do a DFS to find load/store users of the op. If we have a store associated with DPAS layout, or a block load, then we do not mark the MakeTensorPtr op for removal. Otherwise, we mark it for removal.

Next, we make a pass through all the ops again and make sure we remove all MakeTensorPtr-related ops for each MakeTensorPtr marked for removal (tt.advance, rewriting the loads, etc.).
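Roughly, as a sketch rather than the actual pass code (the attribute name and the elided DPAS-layout query are assumptions):

```cpp
// Condensed sketch of phase 1: walk the module, DFS from each
// tt.make_tensor_ptr to its load/store users, and mark for removal any
// tensor pointer that does not reach a "keep" case. Phase 2 (rewriting
// tt.advance, loads, stores, and loop-carried values) is elided.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
namespace tt = mlir::triton;

static void markTensorPtrsForRemoval(ModuleOp mod,
                                     DenseSet<Operation *> &toRemove) {
  mod.walk([&](tt::MakeTensorPtrOp makePtr) {
    SmallVector<Value> workingSet{makePtr.getResult()};
    bool keep = false;
    while (!workingSet.empty() && !keep) {
      Value v = workingSet.pop_back_val();
      for (Operation *user : v.getUsers()) {
        if (isa<tt::LoadOp, tt::StoreOp>(user)) {
          // Keep decision: block attribute (assumed name) or DPAS-layout
          // store; the layout query itself is elided here.
          keep |= user->hasAttr("triton_intel_gpu.block_io");
        } else if (auto forOp = dyn_cast<scf::ForOp>(user)) {
          // Follow the pointer into the loop body via the tied iter arg.
          for (auto [init, iterArg] :
               llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs()))
            if (init == v)
              workingSet.push_back(iterArg);
        } else {
          // e.g. tt.advance: keep following the op's results.
          llvm::append_range(workingSet, user->getResults());
        }
      }
    }
    if (!keep)
      toRemove.insert(makePtr.getOperation());
  });
}
```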

I also had to move the fp8 column-major block load logic to MaterializeBlockPointer, but that is expected if the goal is to clean up duplication.

@alexbaden alexbaden force-pushed the alex/use_block_load_attr_in_rewrite_tensor_pointer branch from 546e577 to ccb3953 on October 9, 2024 19:56
}
}
};
workingSet.erase(crtOpItr);
Contributor

Any chance of inserting elements that were erased before? Do we need a visited set?

Contributor Author

I don't think so - we would need a case where the result from an op is used by an op lower down in the call graph, but that op returns its result to an op higher in the call graph. Maybe if we had conditional branching, but I am not sure how that would be handled.
I might be able to refactor this to eliminate the workingSet, but for now I added a contrived lit test that has tensor ptrs but no loads or stores. This is the scenario I think is most likely to get into an infinite loop, and it passes.
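For reference, if re-insertion ever became a concern, a visited set would guarantee termination regardless of the use-def structure - something like this sketch (not what the PR currently does):

```cpp
// Worklist traversal guarded by a visited set, so it terminates even if a
// use-def chain were ever to feed back on itself. Defensive sketch only.
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

static void visitUsersOnce(mlir::Operation *root,
                           llvm::function_ref<void(mlir::Operation *)> visit) {
  llvm::SmallPtrSet<mlir::Operation *, 16> seen;
  llvm::SmallVector<mlir::Operation *> workingSet(root->getUsers().begin(),
                                                  root->getUsers().end());
  while (!workingSet.empty()) {
    mlir::Operation *op = workingSet.pop_back_val();
    if (!seen.insert(op).second)
      continue; // already visited: skip instead of looping forever
    visit(op);
    for (mlir::Operation *user : op->getUsers())
      workingSet.push_back(user);
  }
}
```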

Contributor

I suspect the following LLVMIR code would run into an infinite loop

loop:
  %0 = phi [%1, loop]
  %1 = add %0, 1
  br cond, loop, exit

so in MLIR, maybe something like

%0 = maketensorptr
%2:1 = scf.for %iv = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg = %0) {
  %1 = add %arg, 1
  scf.yield %1
}

@alexbaden
Contributor Author

That last commit should instead say remove stores too...

@alexbaden
Contributor Author

I don't know why github won't let me reply directly to your comment, but isn't

%0 = maketensorptr
%2:1 = scf.for %iv = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg = %0) {
  %1 = add %arg, 1
  scf.yield %1
}

the same as the test I added, with add replaced by tt.advance? tt.advance is handled generically; presumably add would be handled the same way. Here's the debug log output from the newly added case for the first make_tensor_ptr op:

[tritonintelgpu-rewrite-tensor-pointer]: Considering: %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>
[tritonintelgpu-rewrite-tensor-pointer]: Processing op: %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>)  : i32 {
  %16 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>
  %17 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>
  scf.yield %arg4, %16, %17 : tensor<256x256xf32, #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>
}
[tritonintelgpu-rewrite-tensor-pointer]: Processing op: %16 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>
[tritonintelgpu-rewrite-tensor-pointer]: Not a load store and not a loop, adding users to working set.
[tritonintelgpu-rewrite-tensor-pointer]: Processing op: scf.yield %arg4, %16, %17 : tensor<256x256xf32, #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>>

scf.yield has no users so we stop the iteration there.

There might be cases this misses, but I think we should add them as we find them. Considering all possible cases upfront will be quite difficult, and note that the old algorithm behaved similarly (it only considered the inner body of the scf.for loop and the yield operands).

@alexbaden alexbaden force-pushed the alex/use_block_load_attr_in_rewrite_tensor_pointer branch from 30ed7d2 to 0a42f99 on October 10, 2024 01:11
This caused the lambda in walk to be typed, which means we need an explicit return everywhere.

This reverts commit 0a42f99.
@whitneywhtsang
Contributor

> I don't know why github won't let me reply directly to your comment

Probably you need to reply to the original thread, which is higher up, but it doesn't matter.

> scf.yield has no users so we stop the iteration there.

True, I guess it is fine since MLIR doesn't have PHI nodes.

> There might be cases this misses, but I think we should add them as we find them. Considering all possible cases upfront will be quite difficult, and note the old algorithm behaved similarly (only considered the inner body of scf.for loop and the yield operands).

My concern is not whether we generate good code, but whether we could ever run into an infinite loop, which is behavior we never want. Anyway, I cannot think of an example.

@alexbaden alexbaden merged commit 734e33c into main Oct 10, 2024
@alexbaden alexbaden deleted the alex/use_block_load_attr_in_rewrite_tensor_pointer branch October 10, 2024 11:18
alexbaden added a commit that referenced this pull request Oct 10, 2024
We cannot lower a transposed A matrix to a transposed 2D block load.
Instead, the load is lowered via the LLVM path introduced in #2181.
There appears to be a performance regression in this path, which is
slower than materializing the block in SLM and then reading into
registers and computing the dot product from there. Using the work in
#2420 I am able to drop the block load attribute for this case and go
down the non-block-ptr path.

Performance on main:
```
Compute A x B
✅ Triton and Torch match
Time for torch: 0.32444801926612854 ms
Time for triton: 0.44371041655540466 ms
Compute A x B.T
✅ Triton and Torch match
Time for torch: 0.32708799839019775 ms
Time for triton: 0.634996771812439 ms
Compute A.T x B
✅ Triton and Torch match
Time for torch: 0.31204161047935486 ms
Time for triton: 3.4140689373016357 ms
Compute A.T x B.T
✅ Triton and Torch match
Time for torch: 0.45701122283935547 ms
Time for triton: 3.7463345527648926 ms
```

Performance on this PR:
```
Compute A x B
✅ Triton and Torch match
Time for torch: 0.3081200122833252 ms
Time for triton: 0.44333598017692566 ms
Compute A x B.T
✅ Triton and Torch match
Time for torch: 0.33799198269844055 ms
Time for triton: 0.6391856074333191 ms
Compute A.T x B
✅ Triton and Torch match
Time for torch: 0.31700319051742554 ms
Time for triton: 1.5733630657196045 ms
Compute A.T x B.T
✅ Triton and Torch match
Time for torch: 0.45083683729171753 ms
Time for triton: 1.8271965980529785 ms
```

Note that the important commit is
`31386ef1132c3f6cf9cb5f1063ecfab705f4c2a1`. Once #2420 is merged I will
rebase this.

Depends on #2420. Links to #1795.