
Commit ca9d7df

Use affine for index and size computations
1 parent eac8c2b commit ca9d7df

1 file changed (+52, -40)

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 52 additions & 40 deletions
@@ -9,10 +9,15 @@
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -139,57 +144,64 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
     Location loc = readOp.getLoc();
     Value src = readOp.getSource();
-    MemRefType memRefType = cast<MemRefType>(src.getType());
-    ArrayRef<int64_t> shape = memRefType.getShape();
-
-    Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    Value stride = one;
-
-    // Compute the linear index by linearIndex += indices[i] * stride
-    for (int i = shape.size() - 1; i >= 0; --i) {
-      Value currentIndex = readOp.getIndices()[i];
-      Value strideIndexed =
-          rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
-      linearIndex =
-          rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
-
-      if (i == 0)
-        break;
-
-      // Update stride for the next dimension
-      Value nextStride;
-      if (shape[i] != ShapedType::kDynamic) {
-        nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
-      } else {
-        nextStride = rewriter.create<memref::DimOp>(loc, src, i);
-      }
-      stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
-    }
 
-    Value totalSize = one;
-    for (size_t i = 0; i < shape.size(); ++i) {
-      Value dimensionSize;
-      if (shape[i] != ShapedType::kDynamic) {
-        dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
-      } else {
-        dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
-      }
-      totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    // Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<OpFoldResult> indices = readOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    memref::LinearizedMemRefInfo linearizedInfo;
+    OpFoldResult linearizedIndices;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, elementBitWidth, elementBitWidth,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(), indices);
+    // OpFoldResult linearIndexSize = linearizedInfo.linearizedSize;
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+
+    // Note: the code below doesn't give the correct result for the linearized
+    // size. It computes the product of all dimension sizes instead of taking
+    // the maximum of each dimension size * stride.
+    // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
+    // Value totalSize = getValueOrCreateConstantIndexOp(
+    //     rewriter, loc, linearizedInfo.linearizedSize);
+    SmallVector<AffineExpr> productExpressions;
+    SmallVector<Value> productResults;
+    unsigned sourceRank =
+        cast<ShapedType>(readOp.getSource().getType()).getRank();
+
+    SmallVector<AffineExpr> symbols(2 * sourceRank);
+    SmallVector<Value> offsetValues(2 * sourceRank);
+    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+    for (size_t i = 0; i < sourceRank; ++i) {
+      unsigned offsetIdx = 2 * i;
+      productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
+      offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
+      offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
     }
 
+    AffineMap maxMap = AffineMap::get(
+        /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
+        rewriter.getContext());
+    Value totalSize =
+        rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
+
     // delta = bufferSize - linearizedOffset
-    // 1) check if delta < vectorSize
-    VectorType vectorType = readOp.getVectorType();
-    int64_t vectorSize = vectorType.getNumElements();
     Value vectorSizeOffset =
         rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
     Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+    // 1) check if delta < vectorSize
     Value isOutofBounds = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
 
     // 2) check if (delta (bytes) % (32 / elementBitWidth) != 0)
-    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
     Value deltaBytes = rewriter.create<arith::MulIOp>(
         loc, delta,
         rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
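
For context on what the new code emits: for a rank-2 source, the loop over sourceRank builds the max map ()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3), applied to (stride0, size0, stride1, size1), so the affine.max yields max_i(size_i * stride_i), while the linearized index from getLinearizedMemRefOffsetAndSize() is offset + sum_i(index_i * stride_i). The snippet below is a minimal standalone sketch (plain C++, not MLIR builder code) of that arithmetic, using a hypothetical 4x8 memref with strides [16, 1] and indices (2, 3); it is an illustration of the computation, not part of the pass.

// Plain C++ mirror of the arithmetic the rewritten pattern materializes as IR.
// All concrete values here are hypothetical.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical 4x8 memref whose outer dimension has a non-contiguous stride.
  std::vector<int64_t> sizes = {4, 8};
  std::vector<int64_t> strides = {16, 1};
  std::vector<int64_t> indices = {2, 3};
  int64_t offset = 0;

  // Linearized index: offset + sum_i(indices[i] * strides[i]), the quantity
  // that getLinearizedMemRefOffsetAndSize() folds into linearizedIndices.
  int64_t linearIndex = offset;
  for (size_t i = 0; i < indices.size(); ++i)
    linearIndex += indices[i] * strides[i];

  // Buffer size: max_i(sizes[i] * strides[i]), the quantity the affine.max
  // over the per-dimension size * stride products evaluates to.
  int64_t totalSize = 0;
  for (size_t i = 0; i < sizes.size(); ++i)
    totalSize = std::max(totalSize, sizes[i] * strides[i]);

  // delta = bufferSize - linearizedOffset, compared against the vector size
  // to detect a potentially out-of-bounds transfer_read.
  int64_t delta = totalSize - linearIndex;
  std::printf("linearIndex = %lld, totalSize = %lld, delta = %lld\n",
              static_cast<long long>(linearIndex),
              static_cast<long long>(totalSize),
              static_cast<long long>(delta));
  // Prints: linearIndex = 35, totalSize = 64, delta = 29
}

With these hypothetical strides, the product of the sizes (4 * 8 = 32) would understate the accessible span (64), which is the mismatch the in-diff note and TODO about getLinearizedMemRefOffsetAndSize() refer to.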
