|
9 | 9 | #include "mlir/Dialect/AMDGPU/Transforms/Passes.h" |
10 | 10 |
|
11 | 11 | #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
| 12 | +#include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 13 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 14 | +#include "mlir/Dialect/Arith/Utils/Utils.h" |
12 | 15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 16 | +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
13 | 17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
14 | 18 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
15 | 19 | #include "mlir/IR/BuiltinTypes.h" |
| 20 | +#include "mlir/IR/OpDefinition.h" |
16 | 21 | #include "mlir/IR/PatternMatch.h" |
17 | 22 | #include "mlir/IR/TypeUtilities.h" |
18 | 23 | #include "mlir/Pass/Pass.h" |
@@ -139,57 +144,64 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> { |
139 | 144 |
|
140 | 145 | Location loc = readOp.getLoc(); |
141 | 146 | Value src = readOp.getSource(); |
142 | | - MemRefType memRefType = cast<MemRefType>(src.getType()); |
143 | | - ArrayRef<int64_t> shape = memRefType.getShape(); |
144 | | - |
145 | | - Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
146 | | - Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
147 | | - Value stride = one; |
148 | | - |
149 | | - // Compute the linear index by linearIndex += indices[i] * stride |
150 | | - for (int i = shape.size() - 1; i >= 0; --i) { |
151 | | - Value currentIndex = readOp.getIndices()[i]; |
152 | | - Value strideIndexed = |
153 | | - rewriter.create<arith::MulIOp>(loc, currentIndex, stride); |
154 | | - linearIndex = |
155 | | - rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed); |
156 | | - |
157 | | - if (i == 0) |
158 | | - break; |
159 | | - |
160 | | - // Update stride for the next dimension |
161 | | - Value nextStride; |
162 | | - if (shape[i] != ShapedType::kDynamic) { |
163 | | - nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]); |
164 | | - } else { |
165 | | - nextStride = rewriter.create<memref::DimOp>(loc, src, i); |
166 | | - } |
167 | | - stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride); |
168 | | - } |
169 | 147 |
|
170 | | - Value totalSize = one; |
171 | | - for (size_t i = 0; i < shape.size(); ++i) { |
172 | | - Value dimensionSize; |
173 | | - if (shape[i] != ShapedType::kDynamic) { |
174 | | - dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]); |
175 | | - } else { |
176 | | - dimensionSize = rewriter.create<memref::DimOp>(loc, src, i); |
177 | | - } |
178 | | - totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize); |
| 148 | + VectorType vectorType = readOp.getVectorType(); |
| 149 | + int64_t vectorSize = vectorType.getNumElements(); |
| 150 | + int64_t elementBitWidth = vectorType.getElementTypeBitWidth(); |
| 151 | + // Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| 152 | + SmallVector<OpFoldResult> indices = readOp.getIndices(); |
| 153 | + |
| 154 | + auto stridedMetadata = |
| 155 | + rewriter.create<memref::ExtractStridedMetadataOp>(loc, src); |
| 156 | + memref::LinearizedMemRefInfo linearizedInfo; |
| 157 | + OpFoldResult linearizedIndices; |
| 158 | + std::tie(linearizedInfo, linearizedIndices) = |
| 159 | + memref::getLinearizedMemRefOffsetAndSize( |
| 160 | + rewriter, loc, elementBitWidth, elementBitWidth, |
| 161 | + stridedMetadata.getConstifiedMixedOffset(), |
| 162 | + stridedMetadata.getConstifiedMixedSizes(), |
| 163 | + stridedMetadata.getConstifiedMixedStrides(), indices); |
| 164 | + // OpFoldResult linearIndexSize = linearizedInfo.linearizedSize; |
| 165 | + Value linearIndex = |
| 166 | + getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices); |
| 167 | + |
| 168 | + // Note: the code below doesn't give the correct result for the linearized size. |
| 169 | + // It computes the multiplied sizes of all dimensions instead of taking |
| 170 | + // the maximum of each dimension size * stride. |
| 171 | + // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function |
| 172 | + // Value totalSize = getValueOrCreateConstantIndexOp( |
| 173 | + // rewriter, loc, linearizedInfo.linearizedSize); |
| 174 | + SmallVector<AffineExpr> productExpressions; |
| 175 | + SmallVector<Value> productResults; |
| 176 | + unsigned sourceRank = |
| 177 | + cast<ShapedType>(readOp.getSource().getType()).getRank(); |
| 178 | + |
| 179 | + SmallVector<AffineExpr> symbols(2 * sourceRank); |
| 180 | + SmallVector<Value> offsetValues(2 * sourceRank); |
| 181 | + bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols}); |
| 182 | + for (size_t i = 0; i < sourceRank; ++i) { |
| 183 | + unsigned offsetIdx = 2 * i; |
| 184 | + productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]); |
| 185 | + offsetValues[offsetIdx] = stridedMetadata.getStrides()[i]; |
| 186 | + offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i]; |
179 | 187 | } |
180 | 188 |
|
| 189 | + AffineMap maxMap = AffineMap::get( |
| 190 | + /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions, |
| 191 | + rewriter.getContext()); |
| 192 | + Value totalSize = |
| 193 | + rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues); |
| 194 | + |
181 | 195 | // delta = bufferSize - linearizedOffset |
182 | | - // 1) check if delta < vectorSize |
183 | | - VectorType vectorType = readOp.getVectorType(); |
184 | | - int64_t vectorSize = vectorType.getNumElements(); |
185 | 196 | Value vectorSizeOffset = |
186 | 197 | rewriter.create<arith::ConstantIndexOp>(loc, vectorSize); |
187 | 198 | Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex); |
| 199 | + |
| 200 | + // 1) check if delta < vectorSize |
188 | 201 | Value isOutofBounds = rewriter.create<arith::CmpIOp>( |
189 | 202 | loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset); |
190 | 203 |
|
191 | 204 | // 2) check if (delta(bytes) % (32 / elementBitwidth) != 0) |
192 | | - int64_t elementBitWidth = vectorType.getElementTypeBitWidth(); |
193 | 205 | Value deltaBytes = rewriter.create<arith::MulIOp>( |
194 | 206 | loc, delta, |
195 | 207 | rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8)); |
|
0 commit comments