@@ -104,9 +104,9 @@ Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
104104 return expandOffsets (builder, loc, blockShape, offsets, dim);
105105}
106106
107- Value generatePtr (OpBuilder &builder, const Location & loc,
108- ArrayRef<std:: int64_t > blockShape, Descriptor &desc ,
109- ValueRange offsets) {
107+ Value generatePtrFromOffsetRanges (OpBuilder &builder, Location loc,
108+ ArrayRef<int64_t > blockShape,
109+ Descriptor &desc, ValueRange offsets) {
110110 assert (blockShape.size () == desc.shape .size ());
111111 assert (blockShape.size () == offsets.size ());
112112 auto indexTensorType =
@@ -117,15 +117,12 @@ Value generatePtr(OpBuilder &builder, const Location &loc,
117117 // Generate offsets per dimension
118118 Value ptr = builder.create <triton::SplatOp>(loc, ptrTensorType, desc.base );
119119 for (unsigned i = 0 ; i < blockShape.size (); ++i) {
120- auto offsetWithRange =
121- getExpandedOffsetWithRange (builder, loc, blockShape, offsets[i], i);
122-
123120 // We must splat strides into the expanded shape not a row for retaining
124121 // the divisibility information given by strides
125122 Value splatStride = builder.create <triton::SplatOp>(
126- loc, offsetWithRange .getType (), desc.strides [i]);
123+ loc, offsets[i] .getType (), desc.strides [i]);
127124 Value offsetWithStride =
128- builder.create <arith::MulIOp>(loc, offsetWithRange , splatStride);
125+ builder.create <arith::MulIOp>(loc, offsets[i] , splatStride);
129126 Value broadcasted = builder.create <triton::BroadcastOp>(
130127 loc, indexTensorType, offsetWithStride);
131128
@@ -137,32 +134,47 @@ Value generatePtr(OpBuilder &builder, const Location &loc,
137134 return ptr;
138135}
139136
140- Value generateMask (OpBuilder &builder, const Location &loc,
141- ArrayRef<std::int64_t > blockShape, Descriptor &desc,
142- ValueRange offsets) {
137+ Value generatePtr (OpBuilder &builder, const Location &loc,
138+ ArrayRef<std::int64_t > blockShape, Descriptor &desc,
139+ ValueRange offsets) {
143140 assert (blockShape.size () == desc.shape .size ());
144141 assert (blockShape.size () == offsets.size ());
142+ SmallVector<Value> offsetRanges;
143+ for (unsigned i = 0 ; i < blockShape.size (); ++i) {
144+ auto offsetWithRange =
145+ getExpandedOffsetWithRange (builder, loc, blockShape, offsets[i], i);
146+ offsetRanges.push_back (offsetWithRange);
147+ }
148+
149+ return generatePtrFromOffsetRanges (builder, loc, blockShape, desc,
150+ offsetRanges);
151+ }
152+
153+ Value generateMaskFromOffsetRanges (OpBuilder &builder, const Location &loc,
154+ ArrayRef<std::int64_t > blockShape,
155+ Descriptor &desc, ValueRange offsetRanges) {
156+ assert (blockShape.size () == desc.shape .size ());
157+ assert (blockShape.size () == offsetRanges.size ());
145158
146159 // Generate mask per dimension
147160 auto maskTensorType = RankedTensorType::get (blockShape, builder.getI1Type ());
148161 Value mask;
149162 for (std::size_t i = 0 ; i < blockShape.size (); ++i) {
150- auto offsetWithRange =
151- getExpandedOffsetWithRange (builder, loc, blockShape, offsets[i], i);
163+ auto offsetWithRange = offsetRanges[i];
152164
153165 // Compare with lower bound
154166 Value lowerBound = builder.create <mlir::arith::ConstantIntOp>(
155167 loc, 0 , builder.getI64Type ());
156168 Value splatLowerBound = builder.create <triton::SplatOp>(
157- loc, offsetWithRange .getType (), lowerBound);
169+ loc, offsetRanges[i] .getType (), lowerBound);
158170 Value cmpLower = builder.create <arith::CmpIOp>(
159- loc, arith::CmpIPredicate::sge, offsetWithRange , splatLowerBound);
171+ loc, arith::CmpIPredicate::sge, offsetRanges[i] , splatLowerBound);
160172
161173 // Compare with upper bound
162174 Value splatUpperBound = builder.create <triton::SplatOp>(
163- loc, offsetWithRange .getType (), desc.shape [i]);
175+ loc, offsetRanges[i] .getType (), desc.shape [i]);
164176 Value cmpUpper = builder.create <arith::CmpIOp>(
165- loc, arith::CmpIPredicate::slt, offsetWithRange , splatUpperBound);
177+ loc, arith::CmpIPredicate::slt, offsetRanges[i] , splatUpperBound);
166178
167179 // And and broadcast
168180 Value andResult = builder.create <arith::AndIOp>(loc, cmpLower, cmpUpper);
@@ -180,15 +192,35 @@ Value generateMask(OpBuilder &builder, const Location &loc,
180192 return mask;
181193}
182194
183- Value generateOther (OpBuilder &builder, const Location &loc,
184- TensorDescType descTy) {
185- auto scalarTy = descTy.getSignlessBlockType ().getElementType ();
186- auto blockTy =
187- RankedTensorType::get (descTy.getBlockType ().getShape (), scalarTy);
195+ Value generateMask (OpBuilder &builder, const Location &loc,
196+ ArrayRef<std::int64_t > blockShape, Descriptor &desc,
197+ ValueRange offsets) {
198+ assert (blockShape.size () == desc.shape .size ());
199+ assert (blockShape.size () == offsets.size ());
200+ SmallVector<Value> offsetRanges;
201+ for (unsigned i = 0 ; i < blockShape.size (); ++i) {
202+ auto offsetWithRange =
203+ getExpandedOffsetWithRange (builder, loc, blockShape, offsets[i], i);
204+ offsetRanges.push_back (offsetWithRange);
205+ }
206+
207+ return generateMaskFromOffsetRanges (builder, loc, blockShape, desc,
208+ offsetRanges);
209+ }
210+
211+ Value generateOther (OpBuilder &builder, Location loc, Type scalarTy,
212+ ArrayRef<int64_t > blockShape) {
213+ auto blockTy = RankedTensorType::get (blockShape, scalarTy);
188214 auto attr = builder.getZeroAttr (blockTy);
189215 return builder.create <arith::ConstantOp>(loc, attr);
190216}
191217
218+ Value generateOther (OpBuilder &builder, Location loc, TensorDescType descTy) {
219+ auto blockTy = descTy.getSignlessBlockType ();
220+ return generateOther (builder, loc, blockTy.getElementType (),
221+ blockTy.getShape ());
222+ }
223+
192224SmallVector<mlir::Value> castToI64 (OpBuilder &builder,
193225 mlir::ValueRange values) {
194226 auto i64Type = builder.getI64Type ();
@@ -261,6 +293,73 @@ struct RewriteStorePattern : OpConversionPattern<triton::DescriptorStoreOp> {
261293 }
262294};
263295
296+ std::pair<Value, Value>
297+ generateGatherScatterPtrMask (OpBuilder &builder, Location loc,
298+ ArrayRef<int64_t > blockShape, Descriptor &desc,
299+ Value xOffsets, Value yOffset) {
300+ Value xOffsetRange =
301+ expandOffsets (builder, loc, blockShape, xOffsets, /* dim=*/ 0 );
302+ yOffset = castToI64 (builder, {yOffset})[0 ];
303+ auto xOffsetI64Ty = RankedTensorType::get (
304+ cast<RankedTensorType>(xOffsetRange.getType ()).getShape (),
305+ yOffset.getType ());
306+ xOffsetRange =
307+ builder.create <arith::ExtSIOp>(loc, xOffsetI64Ty, xOffsetRange);
308+ auto yOffsetRange =
309+ getExpandedOffsetWithRange (builder, loc, blockShape, yOffset, /* dim=*/ 1 );
310+ auto ptr = generatePtrFromOffsetRanges (builder, loc, blockShape, desc,
311+ {xOffsetRange, yOffsetRange});
312+ auto mask = generateMaskFromOffsetRanges (builder, loc, blockShape, desc,
313+ {xOffsetRange, yOffsetRange});
314+ return {ptr, mask};
315+ }
316+
317+ struct RewriteGatherPattern : OpConversionPattern<triton::DescriptorGatherOp> {
318+ using OpConversionPattern<triton::DescriptorGatherOp>::OpConversionPattern;
319+
320+ llvm::LogicalResult
321+ matchAndRewrite (triton::DescriptorGatherOp op, OneToNOpAdaptor adaptor,
322+ ConversionPatternRewriter &rewriter) const override {
323+ auto loc = op.getLoc ();
324+ auto descTy = op.getDesc ().getType ();
325+ const auto blockShape = op.getResult ().getType ().getShape ();
326+ auto desc = unpackDescriptor (descTy, adaptor.getDesc ());
327+ auto [ptr, mask] = generateGatherScatterPtrMask (
328+ rewriter, loc, blockShape, desc, op.getXOffsets (), op.getYOffset ());
329+ auto other = generateOther (rewriter, loc,
330+ descTy.getSignlessBlockType ().getElementType (),
331+ blockShape);
332+ auto newLoad = rewriter.replaceOpWithNewOp <triton::LoadOp>(
333+ op, ptr, mask, other, triton::CacheModifier::NONE,
334+ triton::EvictionPolicy::NORMAL, false );
335+ newLoad->setAttrs (filterSegmentSizes (op->getAttrs ()));
336+
337+ return llvm::success ();
338+ }
339+ };
340+
341+ struct RewriteScatterPattern
342+ : OpConversionPattern<triton::DescriptorScatterOp> {
343+ using OpConversionPattern<triton::DescriptorScatterOp>::OpConversionPattern;
344+
345+ llvm::LogicalResult
346+ matchAndRewrite (triton::DescriptorScatterOp op, OneToNOpAdaptor adaptor,
347+ ConversionPatternRewriter &rewriter) const override {
348+ auto loc = op.getLoc ();
349+ auto descTy = op.getDesc ().getType ();
350+ const auto blockShape = op.getSrc ().getType ().getShape ();
351+ auto desc = unpackDescriptor (descTy, adaptor.getDesc ());
352+ auto [ptr, mask] = generateGatherScatterPtrMask (
353+ rewriter, loc, blockShape, desc, op.getXOffsets (), op.getYOffset ());
354+ auto newStore = rewriter.replaceOpWithNewOp <triton::StoreOp>(
355+ op, ptr, op.getSrc (), mask, triton::CacheModifier::NONE,
356+ triton::EvictionPolicy::NORMAL);
357+ newStore->setAttrs (filterSegmentSizes (op->getAttrs ()));
358+
359+ return llvm::success ();
360+ }
361+ };
362+
264363/* *
265364 * @brief This implements the pass for converting triton tensor descriptor
266365 * loads/stores into indexed loads/stores.
@@ -329,9 +428,9 @@ class TritonRewriteTensorDescriptorToPointerPass
329428 mlir::scf::populateSCFStructuralTypeConversions (converter, patterns);
330429 triton::populateArithTypeConversions (converter, patterns);
331430
332- patterns
333- . add <RewriteMakeTensorDesc, RewriteLoadPattern, RewriteStorePattern>(
334- converter, &getContext ());
431+ patterns. add <RewriteMakeTensorDesc, RewriteLoadPattern, RewriteStorePattern,
432+ RewriteGatherPattern, RewriteScatterPattern>(converter,
433+ &getContext ());
335434
336435 ConversionConfig config;
337436 config.buildMaterializations = false ;