Commit 968ec0a

[NFC] Small cleanup of tensor descriptor rewrite pass (#6821)
1 parent 913218e commit 968ec0a

1 file changed: +81 -99 lines changed

lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp

Lines changed: 81 additions & 99 deletions
@@ -13,7 +13,6 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/LogicalResult.h"
@@ -25,7 +24,6 @@
 #include <mlir/Transforms/DialectConversion.h>
 
 #include <iterator>
-#include <memory>
 
 namespace mlir::triton {
 
@@ -40,16 +38,6 @@ bool hasATensorDescriptorType(mlir::TypeRange types) {
   });
 }
 
-/**
- * @brief Convert integer types to signless. Other types are returned as is.
- */
-mlir::Type toSignlessIntegerType(mlir::Type t) {
-  if (auto intType = llvm::dyn_cast<mlir::IntegerType>(t)) {
-    return mlir::IntegerType::get(intType.getContext(), intType.getWidth());
-  }
-  return t;
-}
-
 using namespace mlir;
 
 /**
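The deleted toSignlessIntegerType helper normalized signed/unsigned integer element types to their signless equivalents of the same width; the hunks below obtain the same result from TensorDescType::getSignlessBlockType() instead. A minimal standalone sketch of that normalization, with plain strings as hypothetical stand-ins for MLIR types:

    #include <string>

    // Signless normalization: an integer type keeps its width but drops its
    // signedness; every other type passes through unchanged.
    std::string toSignless(const std::string &type) {
      if (type == "si32" || type == "ui32")
        return "i32"; // same width, signless
      if (type == "si16" || type == "ui16")
        return "i16";
      return type; // f16, f32, index, ... unchanged
    }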
@@ -66,26 +54,29 @@ filterSegmentSizes(mlir::ArrayRef<NamedAttribute> attrs) {
   return ret;
 }
 
-// Note this has been adapted from RewriteTensorPointer.cpp
-Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
-                                 ArrayRef<std::int64_t> blockShape,
-                                 ValueRange offsets, unsigned i) {
-  // Add range
-  auto indexI32RowType =
-      RankedTensorType::get({blockShape[i]}, builder.getI32Type());
-  auto indexRowType =
-      RankedTensorType::get({blockShape[i]}, builder.getI64Type());
-  Value splatOffset =
-      builder.create<triton::SplatOp>(loc, indexRowType, offsets[i]);
-  Value range = builder.create<triton::MakeRangeOp>(loc, indexI32RowType, 0,
-                                                    blockShape[i]);
-  Value i64Range = builder.create<arith::ExtSIOp>(loc, indexRowType, range);
+struct Descriptor {
+  Value base;
+  ValueRange shape;
+  ValueRange strides;
+};
+
+Descriptor unpackDescriptor(TensorDescType type, ValueRange pack) {
+  int rank = type.getBlockType().getRank();
+  assert(pack.size() == 1 + 2 * rank && "Expected tensor descriptors to be "
+                                        "broken down into a ptr and "
+                                        "`rank` shapes and `rank` strides");
+  Descriptor res;
+  res.base = pack[0];
+  res.shape = pack.slice(1, rank);
+  res.strides = pack.slice(1 + rank, rank);
+  return res;
+}
 
-  // Expand dimensions
-  Value expandedResult =
-      builder.create<arith::AddIOp>(loc, splatOffset, i64Range);
+Value expandOffsets(OpBuilder &builder, Location loc,
+                    ArrayRef<int64_t> blockShape, Value offsets, unsigned dim) {
+  Value expandedResult = offsets;
   for (size_t j = 0; j < blockShape.size(); ++j) {
-    if (j == i) {
+    if (j == dim) {
       continue;
     }
     expandedResult =
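The new Descriptor struct and unpackDescriptor make the 1-to-N layout explicit: each converted descriptor is a flat value list holding the base pointer, then rank shapes, then rank strides. A minimal standalone sketch of the same unpack logic in plain C++ (std::int64_t stands in for mlir::Value; the values are hypothetical):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <span>
    #include <vector>

    using Value = std::int64_t; // hypothetical stand-in for mlir::Value

    struct Descriptor {
      Value base;
      std::span<const Value> shape;
      std::span<const Value> strides;
    };

    // Mirrors unpackDescriptor: pack = [base, shape..., strides...].
    Descriptor unpack(std::size_t rank, std::span<const Value> pack) {
      assert(pack.size() == 1 + 2 * rank && "ptr + rank shapes + rank strides");
      return {pack[0], pack.subspan(1, rank), pack.subspan(1 + rank, rank)};
    }

    int main() {
      // A rank-2 descriptor: base, shape (128, 64), strides (64, 1).
      std::vector<Value> pack = {0x1000, 128, 64, 64, 1};
      Descriptor d = unpack(2, pack);
      assert(d.shape[0] == 128 && d.strides[1] == 1);
    }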
@@ -95,27 +86,44 @@ Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
   return expandedResult;
 }
 
-// Note this has been adapted from RewriteTensorPointer.cpp
+Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
+                                 ArrayRef<std::int64_t> blockShape,
+                                 Value offset, unsigned dim) {
+  // Add range
+  auto indexI32RowType =
+      RankedTensorType::get({blockShape[dim]}, builder.getI32Type());
+  auto indexRowType =
+      RankedTensorType::get({blockShape[dim]}, builder.getI64Type());
+  Value splatOffset =
+      builder.create<triton::SplatOp>(loc, indexRowType, offset);
+  Value range = builder.create<triton::MakeRangeOp>(loc, indexI32RowType, 0,
+                                                    blockShape[dim]);
+  Value i64Range = builder.create<arith::ExtSIOp>(loc, indexRowType, range);
+
+  Value offsets = builder.create<arith::AddIOp>(loc, splatOffset, i64Range);
+  return expandOffsets(builder, loc, blockShape, offsets, dim);
+}
+
 Value generatePtr(OpBuilder &builder, const Location &loc,
-                  ArrayRef<std::int64_t> blockShape, Value base,
-                  ValueRange strides, ValueRange offsets) {
-  assert(blockShape.size() == offsets.size() &&
-         blockShape.size() == strides.size());
+                  ArrayRef<std::int64_t> blockShape, Descriptor &desc,
+                  ValueRange offsets) {
+  assert(blockShape.size() == desc.shape.size());
+  assert(blockShape.size() == offsets.size());
   auto indexTensorType =
       RankedTensorType::get(blockShape, builder.getI64Type());
-  auto ptrType = cast<triton::PointerType>(base.getType());
+  auto ptrType = cast<triton::PointerType>(desc.base.getType());
   auto ptrTensorType = RankedTensorType::get(blockShape, ptrType);
 
   // Generate offsets per dimension
-  Value ptr = builder.create<triton::SplatOp>(loc, ptrTensorType, base);
+  Value ptr = builder.create<triton::SplatOp>(loc, ptrTensorType, desc.base);
   for (unsigned i = 0; i < blockShape.size(); ++i) {
     auto offsetWithRange =
-        getExpandedOffsetWithRange(builder, loc, blockShape, offsets, i);
+        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
 
     // We must splat strides into the expanded shape not a row for retaining
     // the divisibility information given by strides
     Value splatStride = builder.create<triton::SplatOp>(
-        loc, offsetWithRange.getType(), strides[i]);
+        loc, offsetWithRange.getType(), desc.strides[i]);
     Value offsetWithStride =
         builder.create<arith::MulIOp>(loc, offsetWithRange, splatStride);
     Value broadcasted = builder.create<triton::BroadcastOp>(
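generatePtr builds, for every element of the block, the address base + sum over d of (offset_d + lane_d) * stride_d, where lane_d comes from MakeRangeOp expanded to the block shape. A scalar model of that address computation for rank 2, as a hypothetical standalone sketch (the pass emits tensor ops, not loops):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Scalar model of generatePtr for a rank-2 block: every element address
    // is base + (off0 + i) * stride0 + (off1 + j) * stride1.
    std::vector<std::int64_t> blockAddresses(std::int64_t base,
                                             std::int64_t off0, std::int64_t off1,
                                             std::int64_t stride0,
                                             std::int64_t stride1,
                                             int block0, int block1) {
      std::vector<std::int64_t> ptrs;
      ptrs.reserve(static_cast<std::size_t>(block0) * block1);
      for (int i = 0; i < block0; ++i)   // dim 0: splat(offset0) + MakeRange
        for (int j = 0; j < block1; ++j) // dim 1: splat(offset1) + MakeRange
          ptrs.push_back(base + (off0 + i) * stride0 + (off1 + j) * stride1);
      return ptrs;
    }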
@@ -129,19 +137,18 @@ Value generatePtr(OpBuilder &builder, const Location &loc,
   return ptr;
 }
 
-// Note this has been adapted from RewriteTensorPointer.cpp
 Value generateMask(OpBuilder &builder, const Location &loc,
-                   ArrayRef<std::int64_t> blockShape, ValueRange offsets,
-                   ValueRange shape) {
-  assert(blockShape.size() == shape.size() &&
-         blockShape.size() == offsets.size());
+                   ArrayRef<std::int64_t> blockShape, Descriptor &desc,
+                   ValueRange offsets) {
+  assert(blockShape.size() == desc.shape.size());
+  assert(blockShape.size() == offsets.size());
 
   // Generate mask per dimension
   auto maskTensorType = RankedTensorType::get(blockShape, builder.getI1Type());
   Value mask;
   for (std::size_t i = 0; i < blockShape.size(); ++i) {
     auto offsetWithRange =
-        getExpandedOffsetWithRange(builder, loc, blockShape, offsets, i);
+        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
 
     // Compare with lower bound
     Value lowerBound = builder.create<mlir::arith::ConstantIntOp>(
Value lowerBound = builder.create<mlir::arith::ConstantIntOp>(
@@ -153,7 +160,7 @@ Value generateMask(OpBuilder &builder, const Location &loc,
153160

154161
// Compare with upper bound
155162
Value splatUpperBound = builder.create<triton::SplatOp>(
156-
loc, offsetWithRange.getType(), shape[i]);
163+
loc, offsetWithRange.getType(), desc.shape[i]);
157164
Value cmpUpper = builder.create<arith::CmpIOp>(
158165
loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound);
159166
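generateMask guards against accesses outside the tensor: per dimension it checks 0 <= offset_d + lane_d < shape_d and ANDs the per-dimension results together. The predicate the emitted IR evaluates, as a hypothetical scalar sketch for rank 2:

    #include <cstdint>

    // Scalar model of generateMask: element (i, j) of the block is in
    // bounds iff every dimension passes its lower- and upper-bound check.
    bool inBounds(std::int64_t off0, std::int64_t i, std::int64_t shape0,
                  std::int64_t off1, std::int64_t j, std::int64_t shape1) {
      bool dim0 = off0 + i >= 0 && off0 + i < shape0;
      bool dim1 = off1 + j >= 0 && off1 + j < shape1;
      return dim0 && dim1; // per-dimension masks are AND-ed
    }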

@@ -173,20 +180,13 @@ Value generateMask(OpBuilder &builder, const Location &loc,
173180
return mask;
174181
}
175182

176-
// Note this has been adapted from RewriteTensorPointer.cpp. It appears
177-
// to be getting the values used for the masked out elements
178-
Value generateOther(OpBuilder &builder, const Location &loc, Value base,
179-
ArrayRef<std::int64_t> blockShape) {
180-
// Create element attribute
181-
auto elementType = cast<triton::PointerType>(base.getType()).getPointeeType();
182-
auto otherTensorType = RankedTensorType::get(blockShape, elementType);
183-
184-
// Set zero padding value (the default)
185-
TypedAttr attr = builder.getZeroAttr(elementType);
186-
187-
// Create tensor
188-
Value constant = builder.create<arith::ConstantOp>(loc, attr);
189-
return builder.create<triton::SplatOp>(loc, otherTensorType, constant);
183+
Value generateOther(OpBuilder &builder, const Location &loc,
184+
TensorDescType descTy) {
185+
auto scalarTy = descTy.getSignlessBlockType().getElementType();
186+
auto blockTy =
187+
RankedTensorType::get(descTy.getBlockType().getShape(), scalarTy);
188+
auto attr = builder.getZeroAttr(blockTy);
189+
return builder.create<arith::ConstantOp>(loc, attr);
190190
}
191191

192192
SmallVector<mlir::Value> castToI64(OpBuilder &builder,
@@ -201,14 +201,13 @@ struct RewriteMakeTensorDesc : OpConversionPattern<triton::MakeTensorDescOp> {
201201
using OpConversionPattern<triton::MakeTensorDescOp>::OpConversionPattern;
202202

203203
llvm::LogicalResult
204-
matchAndRewrite(triton::MakeTensorDescOp op, OneToNOpAdaptor adaptor,
204+
matchAndRewrite(triton::MakeTensorDescOp op, OpAdaptor adaptor,
205205
ConversionPatternRewriter &rewriter) const override {
206206
SmallVector<mlir::Value> ptrShapeStrides;
207-
// Note that none of these values come from a tensor descriptor so its safe
208-
// to get these directly from the op
209-
llvm::append_values(ptrShapeStrides, op.getBase());
210-
llvm::append_range(ptrShapeStrides, castToI64(rewriter, op.getShape()));
211-
llvm::append_range(ptrShapeStrides, op.getStrides());
207+
llvm::append_values(ptrShapeStrides, adaptor.getBase());
208+
llvm::append_range(ptrShapeStrides,
209+
castToI64(rewriter, adaptor.getShape()));
210+
llvm::append_range(ptrShapeStrides, adaptor.getStrides());
212211
rewriter.replaceOpWithMultiple(op, {ptrShapeStrides});
213212
return mlir::success();
214213
}
@@ -220,26 +219,19 @@ struct RewriteLoadPattern : OpConversionPattern<triton::DescriptorLoadOp> {
220219
llvm::LogicalResult
221220
matchAndRewrite(triton::DescriptorLoadOp op, OneToNOpAdaptor adaptor,
222221
ConversionPatternRewriter &rewriter) const override {
222+
auto loc = op.getLoc();
223223
const auto blockShape = op.getDesc().getType().getBlockType().getShape();
224224
const auto rank = blockShape.size();
225-
assert(adaptor.getDesc().size() == 1 + 2 * rank &&
226-
"Expected tensor descriptors to be "
227-
"broken down into a ptr and "
228-
"`rank` shapes and `rank` strides");
229-
230-
auto base = adaptor.getDesc().front();
231-
auto shape = adaptor.getDesc().slice(1, rank);
232-
auto strides = adaptor.getDesc().slice(1 + rank, rank);
233-
// Note that indices aren't converted so
234-
// we can get them directly here
225+
226+
auto descTy = op.getDesc().getType();
227+
auto desc = unpackDescriptor(descTy, adaptor.getDesc());
235228
auto offsets = castToI64(rewriter, op.getIndices());
236229

237230
auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
238-
op,
239-
generatePtr(rewriter, op->getLoc(), blockShape, base, strides, offsets),
240-
generateMask(rewriter, op->getLoc(), blockShape, offsets, shape),
241-
generateOther(rewriter, op->getLoc(), base, blockShape),
242-
triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL, false);
231+
op, generatePtr(rewriter, loc, blockShape, desc, offsets),
232+
generateMask(rewriter, loc, blockShape, desc, offsets),
233+
generateOther(rewriter, loc, descTy), triton::CacheModifier::NONE,
234+
triton::EvictionPolicy::NORMAL, false);
243235
newLoad->setAttrs(filterSegmentSizes(op->getAttrs()));
244236

245237
return llvm::success();
@@ -252,25 +244,16 @@ struct RewriteStorePattern : OpConversionPattern<triton::DescriptorStoreOp> {
252244
llvm::LogicalResult
253245
matchAndRewrite(triton::DescriptorStoreOp op, OneToNOpAdaptor adaptor,
254246
ConversionPatternRewriter &rewriter) const override {
255-
const auto blockShape = op.getDesc().getType().getBlockType().getShape();
247+
auto loc = op.getLoc();
248+
auto descTy = op.getDesc().getType();
249+
const auto blockShape = descTy.getBlockType().getShape();
256250
const auto rank = blockShape.size();
257-
assert(adaptor.getDesc().size() == 1 + 2 * rank &&
258-
"Expected tensor descriptors to be "
259-
"broken down into a ptr and "
260-
"`rank` shapes and `rank` strides");
261-
262-
auto base = adaptor.getDesc().front();
263-
auto shape = adaptor.getDesc().slice(1, rank);
264-
auto strides = adaptor.getDesc().slice(1 + rank, rank);
265-
// Note that indices aren't converted so
266-
// we can get them directly here
251+
auto desc = unpackDescriptor(descTy, adaptor.getDesc());
267252
auto offsets = castToI64(rewriter, op.getIndices());
268253

269254
auto newStore = rewriter.replaceOpWithNewOp<triton::StoreOp>(
270-
op,
271-
generatePtr(rewriter, op->getLoc(), blockShape, base, strides, offsets),
272-
op.getSrc(),
273-
generateMask(rewriter, op->getLoc(), blockShape, offsets, shape),
255+
op, generatePtr(rewriter, loc, blockShape, desc, offsets), op.getSrc(),
256+
generateMask(rewriter, loc, blockShape, desc, offsets),
274257
triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL);
275258
newStore->setAttrs(filterSegmentSizes(op->getAttrs()));
276259

@@ -331,9 +314,8 @@ class TritonRewriteTensorDescriptorToPointerPass
331314
// for each dimension, i.e., we create 1+2*rank values. Note that tensor
332315
// descriptors may be signed/unsigned integers whereas pointers should
333316
// always be signless.
334-
auto tensorType = t.getBlockType();
335-
out.push_back(triton::getPointerType(
336-
toSignlessIntegerType(tensorType.getElementType())));
317+
auto tensorType = t.getSignlessBlockType();
318+
out.push_back(triton::getPointerType(tensorType.getElementType()));
337319
out.insert(out.end(), 2 * tensorType.getRank(),
338320
mlir::IntegerType::get(t.getContext(), 64));
339321
return mlir::success();
