-
Notifications
You must be signed in to change notification settings - Fork 79
Update structured-to-memref pass to support the new pass pipeline
#217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
178f046 to
8090fcd
Compare
structured-to-memref pass to support the new pass pipeline
This was referenced Jan 10, 2025
nhat-nguyen
added a commit
that referenced
this pull request
Jan 14, 2025
This PR introduces the `triton-to-unstructured` pass which is the first
step towards allowing triton-shared to compile pointer sequences that
cannot be analyzed by `triton-to-structured` (gather / scatter).
This pass attempts to lower all loads and stores of unstructured
pointers to
tts.gather or tts.scatter that take a single base, a tensor of offsets,
an
optional tensor of mask values, and a default value in case of load.
In addition, all pointer-producing ops will be eliminated and replaced
by
offset-producing ops. tts.gather and tts.scatter will use the pointer
directly from the kernel arguments as opposed to pointer produced by ops
such
as tt.addptr and tt.splat.
Example:
```mlir
module {
tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
%cst = arith.constant dense<5> : tensor<64xi32>
%cst_0 = arith.constant dense<10> : tensor<64xi32>
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%1 = arith.divsi %0, %cst_0 : tensor<64xi32>
%2 = arith.addi %1, %cst : tensor<64xi32>
%3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%4 = tt.addptr %3, %2 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%5 = tt.load %4 : tensor<64x!tt.ptr<f32>>
%6 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%7 = tt.addptr %6, %0 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %7, %5 : tensor<64x!tt.ptr<f32>>
tt.return
}
}
```
becomes
```mlir
module {
tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
%cst = arith.constant dense<5> : tensor<64xi32>
%cst_0 = arith.constant dense<10> : tensor<64xi32>
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%1 = arith.divsi %0, %cst_0 : tensor<64xi32>
%2 = arith.addi %1, %cst : tensor<64xi32>
%3 = tts.gather %arg0[%2] : (<f32>, tensor<64xi32>) -> tensor<64xf32>
tts.scatter %3 into %arg1[%0] : tensor<64xf32> into (<f32>, tensor<64xi32>)
tt.return
}
}
```
Current assumptions and limitations:
- For simplicity, the pass assumes that gather / scatter operations load
/
store from / to a single base with a tensor of random offsets. As a
result, the following triton program would not work:
```python
@triton.jit
def gather_simple(in0, in1, out0):
offs = tl.arange(0, 8)
in0_ptrs = in0 + offs
in1_ptrs = in1 + offs
ptrs = tl.cat(in0_ptrs, in1_ptrs, can_reorder=True)
c = tl.load(ptrs)
out_offs = tl.arange(0, 16)
tl.store(out0 + out_offs, c)
```
In the above program, `ptrs` contains 2 bases: `in0` and `in1` after the
`cat` operation.
For more details on the algorithm, see the
`TritonToUnstructuredPass.cpp` file.
# Future work
Future work may include scaling the algorithm to support multiple bases
-- one
possible solution is to let tts.gather and tts.scatter take in an
additional
tensor of base pointers corresponding to the tensor of offsets. But
because
we do not want pointer-producing ops to be present after this pass, we
can
use a tensor of index where each element indicates the index of the
pointer
argument to be used. The drawback is a gather or scatter operation now
needs
one extract lookup to get the base which will affect performance.
---
# Intended lowering pipeline
- triton-to-structured (no changes):
- analyzes structured addptr sequences
- introduces `tts.make_tptr %ptr_arg with offsets and strides`
- introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
- introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
- converts kernel arguments with pointer type to memref
nhat-nguyen
added a commit
that referenced
this pull request
Jan 14, 2025
This PR introduces the `triton-ptr-to-memref` pass responsible for
converting function signature that uses triton ptr to use memref
instead. This is part of the work to allow triton-shared to lower gather
/ scatter pointer sequences.
Much of this code is copied from the current `StructuredToMemref` pass
which will be cleaned up in a later PR.
---
# Intended lowering pipeline
- triton-to-structured (no changes):
- analyzes structured addptr sequences
- introduces `tts.make_tptr %ptr_arg with offsets and strides`
- introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
- introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
- converts kernel arguments with pointer type to memref
nhat-nguyen
added a commit
that referenced
this pull request
Jan 15, 2025
This PR introduces the `unstructured-to-memref` pass responsible for
converting unstructured triton load / store ops to memref load / store
ops. This is part of the work to allow triton-shared to lower gather /
scatter pointer sequences. The pass is intended to be used after running
`--fold-unstructured-ptr`.
Triton load op (gather) is lowered to a `linalg.generic` whose body
contains a load from the offset indicated by the offset provided by
`tts.make_unstructured_tptr`. For load op with mask, an inner-most
`scf.if` is used to return a default value (or the `other` in `tt.load`
if provided) if the corresponding mask value is false.
Example of a load:
```mlir
func.func @gather_simple_mask_with_other(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
%cst = arith.constant -1.000000e+00 : f32
%cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
%load_tensor = bufferization.to_tensor %cast restrict : memref<?xf32>
%out = tensor.empty() : tensor<64xf32>
%gather = linalg.generic {
iterator_types = ["parallel"]
} ins(%offset_tensor, %mask_tensor : tensor<64xi32>, tensor<64xi1>)
outs(%out : tensor<64xf32>) {
^bb0(%offset: i32, %mask: i1, %out: f32):
%yield = scf.if %mask -> (f32) {
%index = arith.index_cast %offset : i32 to index
%extracted = tensor.extract %load_tensor[%index] : tensor<?xf32>
scf.yield %extracted : f32
} else {
scf.yield %cst : f32
}
linalg.yield %yield : f32
} -> tensor<64xf32>
```
Triton store op (scatter) is lowered to an `affine.for` loop nest that
stores the value to the appropriate offset provided by
`tts.make_unstructured_tptr`. Store op with mask is also supported.
Example of a store:
```mlir
func.func @masked_gather_scatter(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
%store_memref = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
affine.for %i = 0 to 4 {
%mask_val = tensor.extract %mask[%i] : tensor<4xi1>
scf.if %mask_val {
%offset_val = tensor.extract %offset_tensor[%i] : tensor<4xi32>
%store_value = tensor.extract %tensor[%i] : tensor<4xf32>
%offset_index = arith.index_cast %offset_val : i32 to index
memref.store %store_value, %store_memref[%offset_index] : memref<?xf32>
}
}
```
---
# Intended lowering pipeline
- triton-to-structured (no changes):
- analyzes structured addptr sequences
- introduces `tts.make_tptr %ptr_arg with offsets and strides`
- introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
- introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
- converts kernel arguments with pointer type to memref
30fa7d8 to
f9fcc9f
Compare
beicy
approved these changes
Jan 16, 2025
beicy
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR simplifies the
structured-to-memrefpass responsible for converting structured triton load / store ops to memref load / store ops. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. Previously, this pass is also responsible for converting scalar pointer load and store into memref; that transformation has now been moved tounstructured-to-memref.In addition, the PR also updates the
triton-to-linalg-experimentalpass to fully utilize all the new passes. Once merged, triton-shared now fully supports gather / scatter. An example test (test_gather_scatter.py) is also added to demonstrate this new capability.Intended lowering pipeline
tts.make_tptr %ptr_arg with offsets and stridestts.loadandtts.storett.loadandtt.storeintacttriton-to-unstructuredpass #210):tts.gatherandtts.scattertt.addptrandtt.splatand replaces them with offset-producing opsstructured-to-memrefpass to support the new pass pipeline #217):ttsdialect tomemrefwith the exception oftts.gatherandtts.scatterunstructured-to-memrefpass #216):tts.gather,tts.scatterinto memreftriton-ptr-to-memrefpass #211):