Commit e9262ad
authored
Introduce
This PR introduces the `unstructured-to-memref` pass responsible for
converting unstructured triton load / store ops to memref load / store
ops. This is part of the work to allow triton-shared to lower gather /
scatter pointer sequences. The pass is intended to be used after running
`--fold-unstructured-ptr`.
Triton load op (gather) is lowered to a `linalg.generic` whose body
contains a load from the offset indicated by the offset provided by
`tts.make_unstructured_tptr`. For load op with mask, an inner-most
`scf.if` is used to return a default value (or the `other` in `tt.load`
if provided) if the corresponding mask value is false.
Example of a load:
```mlir
func.func @gather_simple_mask_with_other(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
%cst = arith.constant -1.000000e+00 : f32
%cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
%load_tensor = bufferization.to_tensor %cast restrict : memref<?xf32>
%out = tensor.empty() : tensor<64xf32>
%gather = linalg.generic {
iterator_types = ["parallel"]
} ins(%offset_tensor, %mask_tensor : tensor<64xi32>, tensor<64xi1>)
outs(%out : tensor<64xf32>) {
^bb0(%offset: i32, %mask: i1, %out: f32):
%yield = scf.if %mask -> (f32) {
%index = arith.index_cast %offset : i32 to index
%extracted = tensor.extract %load_tensor[%index] : tensor<?xf32>
scf.yield %extracted : f32
} else {
scf.yield %cst : f32
}
linalg.yield %yield : f32
} -> tensor<64xf32>
```
Triton store op (scatter) is lowered to an `affine.for` loop nest that
stores the value to the appropriate offset provided by
`tts.make_unstructured_tptr`. Store op with mask is also supported.
Example of a store:
```mlir
func.func @masked_gather_scatter(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
%store_memref = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
affine.for %i = 0 to 4 {
%mask_val = tensor.extract %mask[%i] : tensor<4xi1>
scf.if %mask_val {
%offset_val = tensor.extract %offset_tensor[%i] : tensor<4xi32>
%store_value = tensor.extract %tensor[%i] : tensor<4xf32>
%offset_index = arith.index_cast %offset_val : i32 to index
memref.store %store_value, %store_memref[%offset_index] : memref<?xf32>
}
}
```
---
# Intended lowering pipeline
- triton-to-structured (no changes):
- analyzes structured addptr sequences
- introduces `tts.make_tptr %ptr_arg with offsets and strides`
- introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
- introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
- converts kernel arguments with pointer type to memrefunstructured-to-memref pass (#216)1 parent 91ac8d8 commit e9262ad
File tree
13 files changed
+824
-1
lines changed- lib/Conversion
- UnstructuredToMemref
- test/Conversion/UnstructuredToMemref
- tools
13 files changed
+824
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
| 5 | + | |
5 | 6 | | |
6 | 7 | | |
7 | | - | |
| 8 | + | |
Lines changed: 10 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
Lines changed: 22 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
Lines changed: 18 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
Lines changed: 21 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
5 | 5 | | |
6 | 6 | | |
7 | 7 | | |
| 8 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
0 commit comments