Commit 986cea8
authored
Introduce
This PR introduces the `triton-to-unstructured` pass which is the first
step towards allowing triton-shared to compile pointer sequences that
cannot be analyzed by `triton-to-structured` (gather / scatter).
This pass attempts to lower all loads and stores of unstructured
pointers to
tts.gather or tts.scatter that take a single base, a tensor of offsets,
an
optional tensor of mask values, and a default value in case of load.
In addition, all pointer-producing ops will be eliminated and replaced
by
offset-producing ops. tts.gather and tts.scatter will use the pointer
directly from the kernel arguments as opposed to pointer produced by ops
such
as tt.addptr and tt.splat.
Example:
```mlir
module {
tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
%cst = arith.constant dense<5> : tensor<64xi32>
%cst_0 = arith.constant dense<10> : tensor<64xi32>
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%1 = arith.divsi %0, %cst_0 : tensor<64xi32>
%2 = arith.addi %1, %cst : tensor<64xi32>
%3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%4 = tt.addptr %3, %2 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%5 = tt.load %4 : tensor<64x!tt.ptr<f32>>
%6 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%7 = tt.addptr %6, %0 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %7, %5 : tensor<64x!tt.ptr<f32>>
tt.return
}
}
```
becomes
```mlir
module {
tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
%cst = arith.constant dense<5> : tensor<64xi32>
%cst_0 = arith.constant dense<10> : tensor<64xi32>
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%1 = arith.divsi %0, %cst_0 : tensor<64xi32>
%2 = arith.addi %1, %cst : tensor<64xi32>
%3 = tts.gather %arg0[%2] : (<f32>, tensor<64xi32>) -> tensor<64xf32>
tts.scatter %3 into %arg1[%0] : tensor<64xf32> into (<f32>, tensor<64xi32>)
tt.return
}
}
```
Current assumptions and limitations:
- For simplicity, the pass assumes that gather / scatter operations load
/
store from / to a single base with a tensor of random offsets. As a
result, the following triton program would not work:
```python
@triton.jit
def gather_simple(in0, in1, out0):
offs = tl.arange(0, 8)
in0_ptrs = in0 + offs
in1_ptrs = in1 + offs
ptrs = tl.cat(in0_ptrs, in1_ptrs, can_reorder=True)
c = tl.load(ptrs)
out_offs = tl.arange(0, 16)
tl.store(out0 + out_offs, c)
```
In the above program, `ptrs` contains 2 bases: `in0` and `in1` after the
`cat` operation.
For more details on the algorithm, see the
`TritonToUnstructuredPass.cpp` file.
# Future work
Future work may include scaling the algorithm to support multiple bases
-- one
possible solution is to let tts.gather and tts.scatter take in an
additional
tensor of base pointers corresponding to the tensor of offsets. But
because
we do not want pointer-producing ops to be present after this pass, we
can
use a tensor of index where each element indicates the index of the
pointer
argument to be used. The drawback is a gather or scatter operation now
needs
one extract lookup to get the base which will affect performance.
---
# Intended lowering pipeline
- triton-to-structured (no changes):
- analyzes structured addptr sequences
- introduces `tts.make_tptr %ptr_arg with offsets and strides`
- introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
- introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
- converts kernel arguments with pointer type to memreftriton-to-unstructured pass (#210)1 parent 8b9f5dd commit 986cea8
File tree
28 files changed
+1809
-89
lines changed- lib
- AnalysisStructured
- Conversion
- TritonToUnstructured
- Dialect/TritonStructured/IR
- test/Conversion/TritonToUnstructured
- tools
28 files changed
+1809
-89
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
| 5 | + | |
5 | 6 | | |
Lines changed: 3 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
Lines changed: 15 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
Lines changed: 15 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
Lines changed: 17 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
Lines changed: 9 additions & 7 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | 3 | | |
4 | | - | |
5 | | - | |
6 | 4 | | |
7 | 5 | | |
8 | 6 | | |
9 | | - | |
10 | | - | |
11 | | - | |
12 | | - | |
| 7 | + | |
13 | 8 | | |
14 | 9 | | |
15 | | - | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
16 | 18 | | |
17 | 19 | | |
18 | 20 | | |
| |||
Lines changed: 48 additions & 5 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
120 | 120 | | |
121 | 121 | | |
122 | 122 | | |
123 | | - | |
124 | | - | |
125 | 123 | | |
126 | 124 | | |
127 | 125 | | |
| |||
145 | 143 | | |
146 | 144 | | |
147 | 145 | | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
148 | 191 | | |
149 | 192 | | |
150 | 193 | | |
| |||
170 | 213 | | |
171 | 214 | | |
172 | 215 | | |
173 | | - | |
| 216 | + | |
174 | 217 | | |
175 | 218 | | |
176 | 219 | | |
| |||
182 | 225 | | |
183 | 226 | | |
184 | 227 | | |
185 | | - | |
| 228 | + | |
186 | 229 | | |
187 | 230 | | |
188 | 231 | | |
| |||
201 | 244 | | |
202 | 245 | | |
203 | 246 | | |
204 | | - | |
| 247 | + | |
205 | 248 | | |
206 | 249 | | |
207 | 250 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
11 | | - | |
12 | | - | |
13 | 11 | | |
14 | 12 | | |
15 | | - | |
16 | 13 | | |
17 | 14 | | |
18 | 15 | | |
| |||
33 | 30 | | |
34 | 31 | | |
35 | 32 | | |
36 | | - | |
37 | 33 | | |
38 | 34 | | |
39 | 35 | | |
| |||
42 | 38 | | |
43 | 39 | | |
44 | 40 | | |
45 | | - | |
46 | | - | |
47 | | - | |
48 | | - | |
49 | | - | |
50 | | - | |
51 | | - | |
52 | | - | |
53 | | - | |
54 | | - | |
55 | | - | |
56 | | - | |
57 | | - | |
58 | | - | |
59 | | - | |
60 | | - | |
61 | | - | |
62 | | - | |
63 | | - | |
64 | | - | |
65 | | - | |
66 | | - | |
67 | | - | |
68 | | - | |
69 | | - | |
70 | | - | |
71 | | - | |
72 | | - | |
73 | | - | |
74 | | - | |
75 | | - | |
76 | | - | |
77 | | - | |
78 | | - | |
79 | | - | |
80 | | - | |
81 | | - | |
82 | | - | |
83 | | - | |
84 | | - | |
85 | | - | |
86 | | - | |
87 | | - | |
88 | | - | |
89 | | - | |
90 | | - | |
91 | | - | |
92 | | - | |
93 | | - | |
94 | | - | |
95 | | - | |
96 | | - | |
97 | | - | |
98 | | - | |
99 | | - | |
100 | | - | |
101 | | - | |
102 | | - | |
103 | | - | |
104 | | - | |
105 | | - | |
106 | | - | |
107 | | - | |
108 | | - | |
109 | | - | |
110 | | - | |
111 | | - | |
112 | | - | |
113 | 41 | | |
114 | 42 | | |
115 | 43 | | |
| |||
1159 | 1087 | | |
1160 | 1088 | | |
1161 | 1089 | | |
1162 | | - | |
| 1090 | + | |
1163 | 1091 | | |
1164 | 1092 | | |
1165 | 1093 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | 3 | | |
| 4 | + | |
4 | 5 | | |
5 | 6 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
0 commit comments