13
13
#include " mlir/Support/LLVM.h"
14
14
#include " llvm/ADT/ArrayRef.h"
15
15
#include " llvm/ADT/STLExtras.h"
16
- #include " llvm/ADT/Sequence.h"
17
16
#include " llvm/ADT/SmallVector.h"
18
17
#include " llvm/ADT/SmallVectorExtras.h"
19
18
#include " llvm/Support/LogicalResult.h"
25
24
#include < mlir/Transforms/DialectConversion.h>
26
25
27
26
#include < iterator>
28
- #include < memory>
29
27
30
28
namespace mlir ::triton {
31
29
@@ -40,16 +38,6 @@ bool hasATensorDescriptorType(mlir::TypeRange types) {
40
38
});
41
39
}
42
40
43
- /* *
44
- * @brief Convert integer types to signless. Other types are returned as is.
45
- */
46
- mlir::Type toSignlessIntegerType (mlir::Type t) {
47
- if (auto intType = llvm::dyn_cast<mlir::IntegerType>(t)) {
48
- return mlir::IntegerType::get (intType.getContext (), intType.getWidth ());
49
- }
50
- return t;
51
- }
52
-
53
41
using namespace mlir ;
54
42
55
43
/* *
@@ -66,26 +54,29 @@ filterSegmentSizes(mlir::ArrayRef<NamedAttribute> attrs) {
66
54
return ret;
67
55
}
68
56
69
- // Note this has been adapted from RewriteTensorPointer.cpp
70
- Value getExpandedOffsetWithRange (OpBuilder &builder, const Location &loc,
71
- ArrayRef<std::int64_t > blockShape,
72
- ValueRange offsets, unsigned i) {
73
- // Add range
74
- auto indexI32RowType =
75
- RankedTensorType::get ({blockShape[i]}, builder.getI32Type ());
76
- auto indexRowType =
77
- RankedTensorType::get ({blockShape[i]}, builder.getI64Type ());
78
- Value splatOffset =
79
- builder.create <triton::SplatOp>(loc, indexRowType, offsets[i]);
80
- Value range = builder.create <triton::MakeRangeOp>(loc, indexI32RowType, 0 ,
81
- blockShape[i]);
82
- Value i64Range = builder.create <arith::ExtSIOp>(loc, indexRowType, range);
57
+ struct Descriptor {
58
+ Value base;
59
+ ValueRange shape;
60
+ ValueRange strides;
61
+ };
62
+
63
+ Descriptor unpackDescriptor (TensorDescType type, ValueRange pack) {
64
+ int rank = type.getBlockType ().getRank ();
65
+ assert (pack.size () == 1 + 2 * rank && " Expected tensor descriptors to be "
66
+ " broken down into a ptr and "
67
+ " `rank` shapes and `rank` strides" );
68
+ Descriptor res;
69
+ res.base = pack[0 ];
70
+ res.shape = pack.slice (1 , rank);
71
+ res.strides = pack.slice (1 + rank, rank);
72
+ return res;
73
+ }
83
74
84
- // Expand dimensions
85
- Value expandedResult =
86
- builder. create <arith::AddIOp>(loc, splatOffset, i64Range) ;
75
+ Value expandOffsets (OpBuilder &builder, Location loc,
76
+ ArrayRef< int64_t > blockShape, Value offsets, unsigned dim) {
77
+ Value expandedResult = offsets ;
87
78
for (size_t j = 0 ; j < blockShape.size (); ++j) {
88
- if (j == i ) {
79
+ if (j == dim ) {
89
80
continue ;
90
81
}
91
82
expandedResult =
@@ -95,27 +86,44 @@ Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
95
86
return expandedResult;
96
87
}
97
88
98
- // Note this has been adapted from RewriteTensorPointer.cpp
89
+ Value getExpandedOffsetWithRange (OpBuilder &builder, const Location &loc,
90
+ ArrayRef<std::int64_t > blockShape,
91
+ Value offset, unsigned dim) {
92
+ // Add range
93
+ auto indexI32RowType =
94
+ RankedTensorType::get ({blockShape[dim]}, builder.getI32Type ());
95
+ auto indexRowType =
96
+ RankedTensorType::get ({blockShape[dim]}, builder.getI64Type ());
97
+ Value splatOffset =
98
+ builder.create <triton::SplatOp>(loc, indexRowType, offset);
99
+ Value range = builder.create <triton::MakeRangeOp>(loc, indexI32RowType, 0 ,
100
+ blockShape[dim]);
101
+ Value i64Range = builder.create <arith::ExtSIOp>(loc, indexRowType, range);
102
+
103
+ Value offsets = builder.create <arith::AddIOp>(loc, splatOffset, i64Range);
104
+ return expandOffsets (builder, loc, blockShape, offsets, dim);
105
+ }
106
+
99
107
Value generatePtr (OpBuilder &builder, const Location &loc,
100
- ArrayRef<std::int64_t > blockShape, Value base ,
101
- ValueRange strides, ValueRange offsets) {
102
- assert (blockShape.size () == offsets. size () &&
103
- blockShape.size () == strides .size ());
108
+ ArrayRef<std::int64_t > blockShape, Descriptor &desc ,
109
+ ValueRange offsets) {
110
+ assert (blockShape.size () == desc. shape . size ());
111
+ assert ( blockShape.size () == offsets .size ());
104
112
auto indexTensorType =
105
113
RankedTensorType::get (blockShape, builder.getI64Type ());
106
- auto ptrType = cast<triton::PointerType>(base.getType ());
114
+ auto ptrType = cast<triton::PointerType>(desc. base .getType ());
107
115
auto ptrTensorType = RankedTensorType::get (blockShape, ptrType);
108
116
109
117
// Generate offsets per dimension
110
- Value ptr = builder.create <triton::SplatOp>(loc, ptrTensorType, base);
118
+ Value ptr = builder.create <triton::SplatOp>(loc, ptrTensorType, desc. base );
111
119
for (unsigned i = 0 ; i < blockShape.size (); ++i) {
112
120
auto offsetWithRange =
113
- getExpandedOffsetWithRange (builder, loc, blockShape, offsets, i);
121
+ getExpandedOffsetWithRange (builder, loc, blockShape, offsets[i] , i);
114
122
115
123
// We must splat strides into the expanded shape not a row for retaining
116
124
// the divisibility information given by strides
117
125
Value splatStride = builder.create <triton::SplatOp>(
118
- loc, offsetWithRange.getType (), strides[i]);
126
+ loc, offsetWithRange.getType (), desc. strides [i]);
119
127
Value offsetWithStride =
120
128
builder.create <arith::MulIOp>(loc, offsetWithRange, splatStride);
121
129
Value broadcasted = builder.create <triton::BroadcastOp>(
@@ -129,19 +137,18 @@ Value generatePtr(OpBuilder &builder, const Location &loc,
129
137
return ptr;
130
138
}
131
139
132
- // Note this has been adapted from RewriteTensorPointer.cpp
133
140
Value generateMask (OpBuilder &builder, const Location &loc,
134
- ArrayRef<std::int64_t > blockShape, ValueRange offsets ,
135
- ValueRange shape ) {
136
- assert (blockShape.size () == shape.size () &&
137
- blockShape.size () == offsets.size ());
141
+ ArrayRef<std::int64_t > blockShape, Descriptor &desc ,
142
+ ValueRange offsets ) {
143
+ assert (blockShape.size () == desc. shape .size ());
144
+ assert ( blockShape.size () == offsets.size ());
138
145
139
146
// Generate mask per dimension
140
147
auto maskTensorType = RankedTensorType::get (blockShape, builder.getI1Type ());
141
148
Value mask;
142
149
for (std::size_t i = 0 ; i < blockShape.size (); ++i) {
143
150
auto offsetWithRange =
144
- getExpandedOffsetWithRange (builder, loc, blockShape, offsets, i);
151
+ getExpandedOffsetWithRange (builder, loc, blockShape, offsets[i] , i);
145
152
146
153
// Compare with lower bound
147
154
Value lowerBound = builder.create <mlir::arith::ConstantIntOp>(
@@ -153,7 +160,7 @@ Value generateMask(OpBuilder &builder, const Location &loc,
153
160
154
161
// Compare with upper bound
155
162
Value splatUpperBound = builder.create <triton::SplatOp>(
156
- loc, offsetWithRange.getType (), shape[i]);
163
+ loc, offsetWithRange.getType (), desc. shape [i]);
157
164
Value cmpUpper = builder.create <arith::CmpIOp>(
158
165
loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound);
159
166
@@ -173,20 +180,13 @@ Value generateMask(OpBuilder &builder, const Location &loc,
173
180
return mask;
174
181
}
175
182
176
- // Note this has been adapted from RewriteTensorPointer.cpp. It appears
177
- // to be getting the values used for the masked out elements
178
- Value generateOther (OpBuilder &builder, const Location &loc, Value base,
179
- ArrayRef<std::int64_t > blockShape) {
180
- // Create element attribute
181
- auto elementType = cast<triton::PointerType>(base.getType ()).getPointeeType ();
182
- auto otherTensorType = RankedTensorType::get (blockShape, elementType);
183
-
184
- // Set zero padding value (the default)
185
- TypedAttr attr = builder.getZeroAttr (elementType);
186
-
187
- // Create tensor
188
- Value constant = builder.create <arith::ConstantOp>(loc, attr);
189
- return builder.create <triton::SplatOp>(loc, otherTensorType, constant);
183
+ Value generateOther (OpBuilder &builder, const Location &loc,
184
+ TensorDescType descTy) {
185
+ auto scalarTy = descTy.getSignlessBlockType ().getElementType ();
186
+ auto blockTy =
187
+ RankedTensorType::get (descTy.getBlockType ().getShape (), scalarTy);
188
+ auto attr = builder.getZeroAttr (blockTy);
189
+ return builder.create <arith::ConstantOp>(loc, attr);
190
190
}
191
191
192
192
SmallVector<mlir::Value> castToI64 (OpBuilder &builder,
@@ -201,14 +201,13 @@ struct RewriteMakeTensorDesc : OpConversionPattern<triton::MakeTensorDescOp> {
201
201
using OpConversionPattern<triton::MakeTensorDescOp>::OpConversionPattern;
202
202
203
203
llvm::LogicalResult
204
- matchAndRewrite (triton::MakeTensorDescOp op, OneToNOpAdaptor adaptor,
204
+ matchAndRewrite (triton::MakeTensorDescOp op, OpAdaptor adaptor,
205
205
ConversionPatternRewriter &rewriter) const override {
206
206
SmallVector<mlir::Value> ptrShapeStrides;
207
- // Note that none of these values come from a tensor descriptor so its safe
208
- // to get these directly from the op
209
- llvm::append_values (ptrShapeStrides, op.getBase ());
210
- llvm::append_range (ptrShapeStrides, castToI64 (rewriter, op.getShape ()));
211
- llvm::append_range (ptrShapeStrides, op.getStrides ());
207
+ llvm::append_values (ptrShapeStrides, adaptor.getBase ());
208
+ llvm::append_range (ptrShapeStrides,
209
+ castToI64 (rewriter, adaptor.getShape ()));
210
+ llvm::append_range (ptrShapeStrides, adaptor.getStrides ());
212
211
rewriter.replaceOpWithMultiple (op, {ptrShapeStrides});
213
212
return mlir::success ();
214
213
}
@@ -220,26 +219,19 @@ struct RewriteLoadPattern : OpConversionPattern<triton::DescriptorLoadOp> {
220
219
llvm::LogicalResult
221
220
matchAndRewrite (triton::DescriptorLoadOp op, OneToNOpAdaptor adaptor,
222
221
ConversionPatternRewriter &rewriter) const override {
222
+ auto loc = op.getLoc ();
223
223
const auto blockShape = op.getDesc ().getType ().getBlockType ().getShape ();
224
224
const auto rank = blockShape.size ();
225
- assert (adaptor.getDesc ().size () == 1 + 2 * rank &&
226
- " Expected tensor descriptors to be "
227
- " broken down into a ptr and "
228
- " `rank` shapes and `rank` strides" );
229
-
230
- auto base = adaptor.getDesc ().front ();
231
- auto shape = adaptor.getDesc ().slice (1 , rank);
232
- auto strides = adaptor.getDesc ().slice (1 + rank, rank);
233
- // Note that indices aren't converted so
234
- // we can get them directly here
225
+
226
+ auto descTy = op.getDesc ().getType ();
227
+ auto desc = unpackDescriptor (descTy, adaptor.getDesc ());
235
228
auto offsets = castToI64 (rewriter, op.getIndices ());
236
229
237
230
auto newLoad = rewriter.replaceOpWithNewOp <triton::LoadOp>(
238
- op,
239
- generatePtr (rewriter, op->getLoc (), blockShape, base, strides, offsets),
240
- generateMask (rewriter, op->getLoc (), blockShape, offsets, shape),
241
- generateOther (rewriter, op->getLoc (), base, blockShape),
242
- triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL, false );
231
+ op, generatePtr (rewriter, loc, blockShape, desc, offsets),
232
+ generateMask (rewriter, loc, blockShape, desc, offsets),
233
+ generateOther (rewriter, loc, descTy), triton::CacheModifier::NONE,
234
+ triton::EvictionPolicy::NORMAL, false );
243
235
newLoad->setAttrs (filterSegmentSizes (op->getAttrs ()));
244
236
245
237
return llvm::success ();
@@ -252,25 +244,16 @@ struct RewriteStorePattern : OpConversionPattern<triton::DescriptorStoreOp> {
252
244
llvm::LogicalResult
253
245
matchAndRewrite (triton::DescriptorStoreOp op, OneToNOpAdaptor adaptor,
254
246
ConversionPatternRewriter &rewriter) const override {
255
- const auto blockShape = op.getDesc ().getType ().getBlockType ().getShape ();
247
+ auto loc = op.getLoc ();
248
+ auto descTy = op.getDesc ().getType ();
249
+ const auto blockShape = descTy.getBlockType ().getShape ();
256
250
const auto rank = blockShape.size ();
257
- assert (adaptor.getDesc ().size () == 1 + 2 * rank &&
258
- " Expected tensor descriptors to be "
259
- " broken down into a ptr and "
260
- " `rank` shapes and `rank` strides" );
261
-
262
- auto base = adaptor.getDesc ().front ();
263
- auto shape = adaptor.getDesc ().slice (1 , rank);
264
- auto strides = adaptor.getDesc ().slice (1 + rank, rank);
265
- // Note that indices aren't converted so
266
- // we can get them directly here
251
+ auto desc = unpackDescriptor (descTy, adaptor.getDesc ());
267
252
auto offsets = castToI64 (rewriter, op.getIndices ());
268
253
269
254
auto newStore = rewriter.replaceOpWithNewOp <triton::StoreOp>(
270
- op,
271
- generatePtr (rewriter, op->getLoc (), blockShape, base, strides, offsets),
272
- op.getSrc (),
273
- generateMask (rewriter, op->getLoc (), blockShape, offsets, shape),
255
+ op, generatePtr (rewriter, loc, blockShape, desc, offsets), op.getSrc (),
256
+ generateMask (rewriter, loc, blockShape, desc, offsets),
274
257
triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL);
275
258
newStore->setAttrs (filterSegmentSizes (op->getAttrs ()));
276
259
@@ -331,9 +314,8 @@ class TritonRewriteTensorDescriptorToPointerPass
331
314
// for each dimension, i.e., we create 1+2*rank values. Note that tensor
332
315
// descriptors may be signed/unsigned integers whereas pointers should
333
316
// always be signless.
334
- auto tensorType = t.getBlockType ();
335
- out.push_back (triton::getPointerType (
336
- toSignlessIntegerType (tensorType.getElementType ())));
317
+ auto tensorType = t.getSignlessBlockType ();
318
+ out.push_back (triton::getPointerType (tensorType.getElementType ()));
337
319
out.insert (out.end (), 2 * tensorType.getRank (),
338
320
mlir::IntegerType::get (t.getContext (), 64 ));
339
321
return mlir::success ();
0 commit comments