@@ -104,9 +104,9 @@ Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
   return expandOffsets(builder, loc, blockShape, offsets, dim);
 }
 
-Value generatePtr(OpBuilder &builder, const Location &loc,
-                  ArrayRef<std::int64_t> blockShape, Descriptor &desc,
-                  ValueRange offsets) {
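+// Builds the tensor of pointers from offsets that have already been expanded
+// into per-dimension ranges (see getExpandedOffsetWithRange).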
+Value generatePtrFromOffsetRanges(OpBuilder &builder, Location loc,
+                                  ArrayRef<int64_t> blockShape,
+                                  Descriptor &desc, ValueRange offsets) {
   assert(blockShape.size() == desc.shape.size());
   assert(blockShape.size() == offsets.size());
   auto indexTensorType =
@@ -117,15 +117,12 @@ Value generatePtr(OpBuilder &builder, const Location &loc,
   // Generate offsets per dimension
   Value ptr = builder.create<triton::SplatOp>(loc, ptrTensorType, desc.base);
   for (unsigned i = 0; i < blockShape.size(); ++i) {
-    auto offsetWithRange =
-        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
-
     // We must splat strides into the expanded shape not a row for retaining
     // the divisibility information given by strides
     Value splatStride = builder.create<triton::SplatOp>(
-        loc, offsetWithRange.getType(), desc.strides[i]);
+        loc, offsets[i].getType(), desc.strides[i]);
     Value offsetWithStride =
-        builder.create<arith::MulIOp>(loc, offsetWithRange, splatStride);
+        builder.create<arith::MulIOp>(loc, offsets[i], splatStride);
     Value broadcasted = builder.create<triton::BroadcastOp>(
         loc, indexTensorType, offsetWithStride);
 
@@ -137,32 +134,47 @@ Value generatePtr(OpBuilder &builder, const Location &loc,
   return ptr;
 }
 
-Value generateMask(OpBuilder &builder, const Location &loc,
-                   ArrayRef<std::int64_t> blockShape, Descriptor &desc,
-                   ValueRange offsets) {
+Value generatePtr(OpBuilder &builder, const Location &loc,
+                  ArrayRef<std::int64_t> blockShape, Descriptor &desc,
+                  ValueRange offsets) {
   assert(blockShape.size() == desc.shape.size());
   assert(blockShape.size() == offsets.size());
+  SmallVector<Value> offsetRanges;
+  for (unsigned i = 0; i < blockShape.size(); ++i) {
+    auto offsetWithRange =
+        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
+    offsetRanges.push_back(offsetWithRange);
+  }
+
+  return generatePtrFromOffsetRanges(builder, loc, blockShape, desc,
+                                     offsetRanges);
+}
+
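+// Builds the boundary mask from already-expanded offset ranges; each range is
+// checked against [0, desc.shape[i]).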
+Value generateMaskFromOffsetRanges(OpBuilder &builder, const Location &loc,
+                                   ArrayRef<std::int64_t> blockShape,
+                                   Descriptor &desc, ValueRange offsetRanges) {
+  assert(blockShape.size() == desc.shape.size());
+  assert(blockShape.size() == offsetRanges.size());
 
   // Generate mask per dimension
   auto maskTensorType = RankedTensorType::get(blockShape, builder.getI1Type());
   Value mask;
   for (std::size_t i = 0; i < blockShape.size(); ++i) {
-    auto offsetWithRange =
-        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
+    auto offsetWithRange = offsetRanges[i];
 
     // Compare with lower bound
     Value lowerBound = builder.create<mlir::arith::ConstantIntOp>(
         loc, 0, builder.getI64Type());
     Value splatLowerBound = builder.create<triton::SplatOp>(
-        loc, offsetWithRange.getType(), lowerBound);
+        loc, offsetRanges[i].getType(), lowerBound);
     Value cmpLower = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound);
+        loc, arith::CmpIPredicate::sge, offsetRanges[i], splatLowerBound);
 
     // Compare with upper bound
     Value splatUpperBound = builder.create<triton::SplatOp>(
-        loc, offsetWithRange.getType(), desc.shape[i]);
+        loc, offsetRanges[i].getType(), desc.shape[i]);
     Value cmpUpper = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound);
+        loc, arith::CmpIPredicate::slt, offsetRanges[i], splatUpperBound);
 
     // And and broadcast
     Value andResult = builder.create<arith::AndIOp>(loc, cmpLower, cmpUpper);
@@ -180,15 +192,35 @@ Value generateMask(OpBuilder &builder, const Location &loc,
   return mask;
 }
 
-Value generateOther(OpBuilder &builder, const Location &loc,
-                    TensorDescType descTy) {
-  auto scalarTy = descTy.getSignlessBlockType().getElementType();
-  auto blockTy =
-      RankedTensorType::get(descTy.getBlockType().getShape(), scalarTy);
+Value generateMask(OpBuilder &builder, const Location &loc,
+                   ArrayRef<std::int64_t> blockShape, Descriptor &desc,
+                   ValueRange offsets) {
+  assert(blockShape.size() == desc.shape.size());
+  assert(blockShape.size() == offsets.size());
+  SmallVector<Value> offsetRanges;
+  for (unsigned i = 0; i < blockShape.size(); ++i) {
+    auto offsetWithRange =
+        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
+    offsetRanges.push_back(offsetWithRange);
+  }
+
+  return generateMaskFromOffsetRanges(builder, loc, blockShape, desc,
+                                      offsetRanges);
+}
+
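+// Zero-filled block tensor used as the "other" value for masked-off elements.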
+Value generateOther(OpBuilder &builder, Location loc, Type scalarTy,
+                    ArrayRef<int64_t> blockShape) {
+  auto blockTy = RankedTensorType::get(blockShape, scalarTy);
   auto attr = builder.getZeroAttr(blockTy);
   return builder.create<arith::ConstantOp>(loc, attr);
 }
 
+Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy) {
+  auto blockTy = descTy.getSignlessBlockType();
+  return generateOther(builder, loc, blockTy.getElementType(),
+                       blockTy.getShape());
+}
+
 SmallVector<mlir::Value> castToI64(OpBuilder &builder,
                                    mlir::ValueRange values) {
   auto i64Type = builder.getI64Type();
@@ -261,6 +293,73 @@ struct RewriteStorePattern : OpConversionPattern<triton::DescriptorStoreOp> {
   }
 };
 
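+// Computes the pointer tensor and boundary mask for descriptor gather/scatter
+// ops: the x-offset tensor is expanded along dim 0 and sign-extended to i64,
+// and the scalar y offset is expanded into a range along dim 1.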
+std::pair<Value, Value>
+generateGatherScatterPtrMask(OpBuilder &builder, Location loc,
+                             ArrayRef<int64_t> blockShape, Descriptor &desc,
+                             Value xOffsets, Value yOffset) {
+  Value xOffsetRange =
+      expandOffsets(builder, loc, blockShape, xOffsets, /*dim=*/0);
+  yOffset = castToI64(builder, {yOffset})[0];
+  auto xOffsetI64Ty = RankedTensorType::get(
+      cast<RankedTensorType>(xOffsetRange.getType()).getShape(),
+      yOffset.getType());
+  xOffsetRange =
+      builder.create<arith::ExtSIOp>(loc, xOffsetI64Ty, xOffsetRange);
+  auto yOffsetRange =
+      getExpandedOffsetWithRange(builder, loc, blockShape, yOffset, /*dim=*/1);
+  auto ptr = generatePtrFromOffsetRanges(builder, loc, blockShape, desc,
+                                         {xOffsetRange, yOffsetRange});
+  auto mask = generateMaskFromOffsetRanges(builder, loc, blockShape, desc,
+                                           {xOffsetRange, yOffsetRange});
+  return {ptr, mask};
+}
+
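+// Lowers triton::DescriptorGatherOp to a masked triton::LoadOp over the
+// pointer tensor computed from the descriptor.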
+struct RewriteGatherPattern : OpConversionPattern<triton::DescriptorGatherOp> {
+  using OpConversionPattern<triton::DescriptorGatherOp>::OpConversionPattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(triton::DescriptorGatherOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto descTy = op.getDesc().getType();
+    const auto blockShape = op.getResult().getType().getShape();
+    auto desc = unpackDescriptor(descTy, adaptor.getDesc());
+    auto [ptr, mask] = generateGatherScatterPtrMask(
+        rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset());
+    auto other = generateOther(rewriter, loc,
+                               descTy.getSignlessBlockType().getElementType(),
+                               blockShape);
+    auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
+        op, ptr, mask, other, triton::CacheModifier::NONE,
+        triton::EvictionPolicy::NORMAL, false);
+    newLoad->setAttrs(filterSegmentSizes(op->getAttrs()));
+
+    return llvm::success();
+  }
+};
+
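+// Lowers triton::DescriptorScatterOp to a masked triton::StoreOp, mirroring
+// the gather lowering above.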
+struct RewriteScatterPattern
+    : OpConversionPattern<triton::DescriptorScatterOp> {
+  using OpConversionPattern<triton::DescriptorScatterOp>::OpConversionPattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(triton::DescriptorScatterOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto descTy = op.getDesc().getType();
+    const auto blockShape = op.getSrc().getType().getShape();
+    auto desc = unpackDescriptor(descTy, adaptor.getDesc());
+    auto [ptr, mask] = generateGatherScatterPtrMask(
+        rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset());
+    auto newStore = rewriter.replaceOpWithNewOp<triton::StoreOp>(
+        op, ptr, op.getSrc(), mask, triton::CacheModifier::NONE,
+        triton::EvictionPolicy::NORMAL);
+    newStore->setAttrs(filterSegmentSizes(op->getAttrs()));
+
+    return llvm::success();
+  }
+};
+
 /**
  * @brief This implements the pass for converting triton tensor descriptor
  * loads/stores into indexed loads/stores.
@@ -329,9 +428,9 @@ class TritonRewriteTensorDescriptorToPointerPass
     mlir::scf::populateSCFStructuralTypeConversions(converter, patterns);
     triton::populateArithTypeConversions(converter, patterns);
 
-    patterns
-        .add<RewriteMakeTensorDesc, RewriteLoadPattern, RewriteStorePattern>(
-            converter, &getContext());
+    patterns.add<RewriteMakeTensorDesc, RewriteLoadPattern, RewriteStorePattern,
+                 RewriteGatherPattern, RewriteScatterPattern>(converter,
+                                                              &getContext());
 
     ConversionConfig config;
     config.buildMaterializations = false;