 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/LogicalResult.h"
 #include <mlir/Transforms/DialectConversion.h>
 
 #include <iterator>
-#include <memory>
 
 namespace mlir::triton {
 
@@ -40,16 +38,6 @@ bool hasATensorDescriptorType(mlir::TypeRange types) {
   });
 }
 
-/**
- * @brief Convert integer types to signless. Other types are returned as is.
- */
-mlir::Type toSignlessIntegerType(mlir::Type t) {
-  if (auto intType = llvm::dyn_cast<mlir::IntegerType>(t)) {
-    return mlir::IntegerType::get(intType.getContext(), intType.getWidth());
-  }
-  return t;
-}
-
 using namespace mlir;
 
 /**
@@ -66,26 +54,29 @@ filterSegmentSizes(mlir::ArrayRef<NamedAttribute> attrs) {
   return ret;
 }
 
-// Note this has been adapted from RewriteTensorPointer.cpp
-Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
-                                 ArrayRef<std::int64_t> blockShape,
-                                 ValueRange offsets, unsigned i) {
-  // Add range
-  auto indexI32RowType =
-      RankedTensorType::get({blockShape[i]}, builder.getI32Type());
-  auto indexRowType =
-      RankedTensorType::get({blockShape[i]}, builder.getI64Type());
-  Value splatOffset =
-      builder.create<triton::SplatOp>(loc, indexRowType, offsets[i]);
-  Value range = builder.create<triton::MakeRangeOp>(loc, indexI32RowType, 0,
-                                                    blockShape[i]);
-  Value i64Range = builder.create<arith::ExtSIOp>(loc, indexRowType, range);
+struct Descriptor {
+  Value base;
+  ValueRange shape;
+  ValueRange strides;
+};
+
+Descriptor unpackDescriptor(TensorDescType type, ValueRange pack) {
+  int rank = type.getBlockType().getRank();
+  assert(pack.size() == 1 + 2 * rank && "Expected tensor descriptors to be "
+                                        "broken down into a ptr and "
+                                        "`rank` shapes and `rank` strides");
+  Descriptor res;
+  res.base = pack[0];
+  res.shape = pack.slice(1, rank);
+  res.strides = pack.slice(1 + rank, rank);
+  return res;
+}
 
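For intuition, the 1:N conversion flattens a rank-r descriptor into 1 + 2*r values laid out as [base, shape_0..shape_{r-1}, stride_0..stride_{r-1}]. A minimal stand-alone sketch of that packing convention, using plain integers in place of MLIR Values (the FlatDescriptor and unpackFlat names here are hypothetical, not part of the patch):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct FlatDescriptor {
      std::int64_t base; // stands in for the !tt.ptr value
      std::vector<std::int64_t> shape;
      std::vector<std::int64_t> strides;
    };

    FlatDescriptor unpackFlat(const std::vector<std::int64_t> &pack,
                              std::size_t rank) {
      // Layout produced by the conversion: [base, shape..., strides...].
      assert(pack.size() == 1 + 2 * rank);
      FlatDescriptor d;
      d.base = pack[0];
      d.shape.assign(pack.begin() + 1, pack.begin() + 1 + rank);
      d.strides.assign(pack.begin() + 1 + rank, pack.end());
      return d;
    }

For a rank-2 descriptor the pack is {base, M, N, strideM, strideN}.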
-  // Expand dimensions
-  Value expandedResult =
-      builder.create<arith::AddIOp>(loc, splatOffset, i64Range);
+Value expandOffsets(OpBuilder &builder, Location loc,
+                    ArrayRef<int64_t> blockShape, Value offsets, unsigned dim) {
+  Value expandedResult = offsets;
   for (size_t j = 0; j < blockShape.size(); ++j) {
-    if (j == i) {
+    if (j == dim) {
       continue;
     }
     expandedResult =
@@ -95,27 +86,44 @@ Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
   return expandedResult;
 }
 
-// Note this has been adapted from RewriteTensorPointer.cpp
+Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
+                                 ArrayRef<std::int64_t> blockShape,
+                                 Value offset, unsigned dim) {
+  // Add range
+  auto indexI32RowType =
+      RankedTensorType::get({blockShape[dim]}, builder.getI32Type());
+  auto indexRowType =
+      RankedTensorType::get({blockShape[dim]}, builder.getI64Type());
+  Value splatOffset =
+      builder.create<triton::SplatOp>(loc, indexRowType, offset);
+  Value range = builder.create<triton::MakeRangeOp>(loc, indexI32RowType, 0,
+                                                    blockShape[dim]);
+  Value i64Range = builder.create<arith::ExtSIOp>(loc, indexRowType, range);
+
+  Value offsets = builder.create<arith::AddIOp>(loc, splatOffset, i64Range);
+  return expandOffsets(builder, loc, blockShape, offsets, dim);
+}
+
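A scalar model of what expandOffsets and getExpandedOffsetWithRange compute together: for dimension dim, the scalar offset is added to make_range(0, blockShape[dim]) (sign-extended to i64), then reshaped with size-1 dimensions everywhere else so it broadcasts over the whole block. A hypothetical helper covering the one-dimensional part, with std::vector standing in for ranked tensors:

    #include <cstdint>
    #include <vector>

    std::vector<std::int64_t> offsetWithRange(std::int64_t offset,
                                              std::int64_t blockDim) {
      // splat(offset) + extsi(make_range(0, blockDim))
      std::vector<std::int64_t> row(blockDim);
      for (std::int64_t i = 0; i < blockDim; ++i)
        row[i] = offset + i;
      return row;
    }

For example, offsetWithRange(8, 4) yields {8, 9, 10, 11}.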
 Value generatePtr(OpBuilder &builder, const Location &loc,
-                  ArrayRef<std::int64_t> blockShape, Value base,
-                  ValueRange strides, ValueRange offsets) {
-  assert(blockShape.size() == offsets.size() &&
-         blockShape.size() == strides.size());
+                  ArrayRef<std::int64_t> blockShape, Descriptor &desc,
+                  ValueRange offsets) {
+  assert(blockShape.size() == desc.shape.size());
+  assert(blockShape.size() == offsets.size());
   auto indexTensorType =
       RankedTensorType::get(blockShape, builder.getI64Type());
-  auto ptrType = cast<triton::PointerType>(base.getType());
+  auto ptrType = cast<triton::PointerType>(desc.base.getType());
   auto ptrTensorType = RankedTensorType::get(blockShape, ptrType);
 
   // Generate offsets per dimension
-  Value ptr = builder.create<triton::SplatOp>(loc, ptrTensorType, base);
+  Value ptr = builder.create<triton::SplatOp>(loc, ptrTensorType, desc.base);
   for (unsigned i = 0; i < blockShape.size(); ++i) {
     auto offsetWithRange =
-        getExpandedOffsetWithRange(builder, loc, blockShape, offsets, i);
+        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
 
     // We must splat the strides into the expanded shape, not a row, to
     // retain the divisibility information given by the strides
     Value splatStride = builder.create<triton::SplatOp>(
-        loc, offsetWithRange.getType(), strides[i]);
+        loc, offsetWithRange.getType(), desc.strides[i]);
     Value offsetWithStride =
         builder.create<arith::MulIOp>(loc, offsetWithRange, splatStride);
     Value broadcasted = builder.create<triton::BroadcastOp>(
@@ -129,19 +137,18 @@ Value generatePtr(OpBuilder &builder, const Location &loc,
   return ptr;
 }
 
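The tensor IR built by generatePtr is the broadcast form of ordinary strided addressing: block element (i_0, ..., i_{r-1}) reads from base + sum over d of (offsets[d] + i_d) * strides[d]. A hypothetical scalar reference model of that address computation:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    std::int64_t elementAddress(std::int64_t base,
                                const std::vector<std::int64_t> &offsets,
                                const std::vector<std::int64_t> &strides,
                                const std::vector<std::int64_t> &index) {
      std::int64_t addr = base;
      for (std::size_t d = 0; d < offsets.size(); ++d)
        addr += (offsets[d] + index[d]) * strides[d];
      return addr;
    }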
-// Note this has been adapted from RewriteTensorPointer.cpp
 Value generateMask(OpBuilder &builder, const Location &loc,
-                   ArrayRef<std::int64_t> blockShape, ValueRange offsets,
-                   ValueRange shape) {
-  assert(blockShape.size() == shape.size() &&
-         blockShape.size() == offsets.size());
+                   ArrayRef<std::int64_t> blockShape, Descriptor &desc,
+                   ValueRange offsets) {
+  assert(blockShape.size() == desc.shape.size());
+  assert(blockShape.size() == offsets.size());
 
   // Generate mask per dimension
   auto maskTensorType = RankedTensorType::get(blockShape, builder.getI1Type());
   Value mask;
   for (std::size_t i = 0; i < blockShape.size(); ++i) {
     auto offsetWithRange =
-        getExpandedOffsetWithRange(builder, loc, blockShape, offsets, i);
+        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
 
     // Compare with lower bound
     Value lowerBound = builder.create<mlir::arith::ConstantIntOp>(
@@ -153,7 +160,7 @@ Value generateMask(OpBuilder &builder, const Location &loc,
 
     // Compare with upper bound
     Value splatUpperBound = builder.create<triton::SplatOp>(
-        loc, offsetWithRange.getType(), shape[i]);
+        loc, offsetWithRange.getType(), desc.shape[i]);
     Value cmpUpper = builder.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound);
 
@@ -173,20 +180,13 @@ Value generateMask(OpBuilder &builder, const Location &loc,
   return mask;
 }
 
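A scalar model of the mask generateMask assembles: an element is kept iff 0 <= offsets[d] + i_d < shape[d] in every dimension d, with the per-dimension i1 tensors AND-ed together. A hypothetical reference version:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    bool inBounds(const std::vector<std::int64_t> &offsets,
                  const std::vector<std::int64_t> &shape,
                  const std::vector<std::int64_t> &index) {
      for (std::size_t d = 0; d < offsets.size(); ++d) {
        std::int64_t x = offsets[d] + index[d];
        // cmpLower (sge 0) and cmpUpper (slt shape), AND-reduced over dims
        if (x < 0 || x >= shape[d])
          return false;
      }
      return true;
    }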
-// Note this has been adapted from RewriteTensorPointer.cpp. It appears
-// to be getting the values used for the masked out elements
-Value generateOther(OpBuilder &builder, const Location &loc, Value base,
-                    ArrayRef<std::int64_t> blockShape) {
-  // Create element attribute
-  auto elementType = cast<triton::PointerType>(base.getType()).getPointeeType();
-  auto otherTensorType = RankedTensorType::get(blockShape, elementType);
-
-  // Set zero padding value (the default)
-  TypedAttr attr = builder.getZeroAttr(elementType);
-
-  // Create tensor
-  Value constant = builder.create<arith::ConstantOp>(loc, attr);
-  return builder.create<triton::SplatOp>(loc, otherTensorType, constant);
+Value generateOther(OpBuilder &builder, const Location &loc,
+                    TensorDescType descTy) {
+  auto scalarTy = descTy.getSignlessBlockType().getElementType();
+  auto blockTy =
+      RankedTensorType::get(descTy.getBlockType().getShape(), scalarTy);
+  auto attr = builder.getZeroAttr(blockTy);
+  return builder.create<arith::ConstantOp>(loc, attr);
 }
 
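The other value supplies what a masked-off lane yields; this rewrite keeps the default zero padding, now built directly as a zero constant of the signless block type rather than splatting a scalar. A scalar model of how a masked load consumes it (hypothetical names):

    #include <cstdint>

    std::int64_t maskedLoad(const std::int64_t *ptr, bool mask,
                            std::int64_t other) {
      return mask ? *ptr : other; // other is 0 here by construction
    }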
 SmallVector<mlir::Value> castToI64(OpBuilder &builder,
@@ -201,14 +201,13 @@ struct RewriteMakeTensorDesc : OpConversionPattern<triton::MakeTensorDescOp> {
   using OpConversionPattern<triton::MakeTensorDescOp>::OpConversionPattern;
 
   llvm::LogicalResult
-  matchAndRewrite(triton::MakeTensorDescOp op, OneToNOpAdaptor adaptor,
+  matchAndRewrite(triton::MakeTensorDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<mlir::Value> ptrShapeStrides;
-    // Note that none of these values come from a tensor descriptor so its safe
-    // to get these directly from the op
-    llvm::append_values(ptrShapeStrides, op.getBase());
-    llvm::append_range(ptrShapeStrides, castToI64(rewriter, op.getShape()));
-    llvm::append_range(ptrShapeStrides, op.getStrides());
+    llvm::append_values(ptrShapeStrides, adaptor.getBase());
+    llvm::append_range(ptrShapeStrides,
+                       castToI64(rewriter, adaptor.getShape()));
+    llvm::append_range(ptrShapeStrides, adaptor.getStrides());
     rewriter.replaceOpWithMultiple(op, {ptrShapeStrides});
     return mlir::success();
   }
@@ -220,26 +219,19 @@ struct RewriteLoadPattern : OpConversionPattern<triton::DescriptorLoadOp> {
   llvm::LogicalResult
   matchAndRewrite(triton::DescriptorLoadOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
     const auto blockShape = op.getDesc().getType().getBlockType().getShape();
     const auto rank = blockShape.size();
-    assert(adaptor.getDesc().size() == 1 + 2 * rank &&
-           "Expected tensor descriptors to be "
-           "broken down into a ptr and "
-           "`rank` shapes and `rank` strides");
-
-    auto base = adaptor.getDesc().front();
-    auto shape = adaptor.getDesc().slice(1, rank);
-    auto strides = adaptor.getDesc().slice(1 + rank, rank);
-    // Note that indices aren't converted so
-    // we can get them directly here
+
+    auto descTy = op.getDesc().getType();
+    auto desc = unpackDescriptor(descTy, adaptor.getDesc());
     auto offsets = castToI64(rewriter, op.getIndices());
 
     auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
-        op,
-        generatePtr(rewriter, op->getLoc(), blockShape, base, strides, offsets),
-        generateMask(rewriter, op->getLoc(), blockShape, offsets, shape),
-        generateOther(rewriter, op->getLoc(), base, blockShape),
-        triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL, false);
+        op, generatePtr(rewriter, loc, blockShape, desc, offsets),
+        generateMask(rewriter, loc, blockShape, desc, offsets),
+        generateOther(rewriter, loc, descTy), triton::CacheModifier::NONE,
+        triton::EvictionPolicy::NORMAL, false);
     newLoad->setAttrs(filterSegmentSizes(op->getAttrs()));
 
     return llvm::success();
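Putting the helpers together, a descriptor load lowers to an ordinary masked load. A hypothetical rank-2 reference model, reusing the elementAddress and inBounds sketches above and a flat memory vector:

    std::vector<std::int64_t> descriptorLoad2D(
        std::int64_t base, const std::vector<std::int64_t> &shape,
        const std::vector<std::int64_t> &strides, std::int64_t off0,
        std::int64_t off1, std::int64_t blockM, std::int64_t blockN,
        const std::vector<std::int64_t> &memory) {
      std::vector<std::int64_t> out;
      for (std::int64_t i = 0; i < blockM; ++i)
        for (std::int64_t j = 0; j < blockN; ++j) {
          std::int64_t addr =
              elementAddress(base, {off0, off1}, strides, {i, j});
          bool ok = inBounds({off0, off1}, shape, {i, j});
          out.push_back(ok ? memory[addr] : 0); // other = 0
        }
      return out;
    }

The store pattern below is the same construction with the mask guarding writes instead of selecting a padding value.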
@@ -252,25 +244,16 @@ struct RewriteStorePattern : OpConversionPattern<triton::DescriptorStoreOp> {
   llvm::LogicalResult
   matchAndRewrite(triton::DescriptorStoreOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    const auto blockShape = op.getDesc().getType().getBlockType().getShape();
+    auto loc = op.getLoc();
+    auto descTy = op.getDesc().getType();
+    const auto blockShape = descTy.getBlockType().getShape();
     const auto rank = blockShape.size();
-    assert(adaptor.getDesc().size() == 1 + 2 * rank &&
-           "Expected tensor descriptors to be "
-           "broken down into a ptr and "
-           "`rank` shapes and `rank` strides");
-
-    auto base = adaptor.getDesc().front();
-    auto shape = adaptor.getDesc().slice(1, rank);
-    auto strides = adaptor.getDesc().slice(1 + rank, rank);
-    // Note that indices aren't converted so
-    // we can get them directly here
+    auto desc = unpackDescriptor(descTy, adaptor.getDesc());
     auto offsets = castToI64(rewriter, op.getIndices());
 
     auto newStore = rewriter.replaceOpWithNewOp<triton::StoreOp>(
-        op,
-        generatePtr(rewriter, op->getLoc(), blockShape, base, strides, offsets),
-        op.getSrc(),
-        generateMask(rewriter, op->getLoc(), blockShape, offsets, shape),
+        op, generatePtr(rewriter, loc, blockShape, desc, offsets), op.getSrc(),
+        generateMask(rewriter, loc, blockShape, desc, offsets),
         triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL);
     newStore->setAttrs(filterSegmentSizes(op->getAttrs()));
 
@@ -331,9 +314,8 @@ class TritonRewriteTensorDescriptorToPointerPass
     // for each dimension, i.e., we create 1+2*rank values. Note that tensor
     // descriptors may be signed/unsigned integers whereas pointers should
     // always be signless.
-    auto tensorType = t.getBlockType();
-    out.push_back(triton::getPointerType(
-        toSignlessIntegerType(tensorType.getElementType())));
+    auto tensorType = t.getSignlessBlockType();
+    out.push_back(triton::getPointerType(tensorType.getElementType()));
     out.insert(out.end(), 2 * tensorType.getRank(),
                mlir::IntegerType::get(t.getContext(), 64));
     return mlir::success();
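A concrete, hypothetical instance of the type conversion this implements, for a rank-2 descriptor with a signed element type:

    // !tt.tensordesc<tensor<32x64xsi8>>
    //   -> !tt.ptr<i8>, i64, i64, i64, i64
    //      (base pointer, 2 shape values, 2 stride values)
    // getSignlessBlockType maps si8/ui8 to signless i8 before the
    // pointer type is formed.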