|
6 | 6 | // |
7 | 7 | //===----------------------------------------------------------------------===// |
8 | 8 |
|
| 9 | +#include <numeric> |
| 10 | + |
| 11 | +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" |
9 | 12 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
10 | 13 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
11 | 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
12 | 15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 16 | +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
13 | 17 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
14 | 18 | #include "mlir/IR/Attributes.h" |
15 | 19 | #include "mlir/IR/Matchers.h" |
@@ -155,9 +159,10 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref) { |
155 | 159 | auto srcType = cast<MemRefType>(srcMemref.getType()); |
156 | 160 |
|
157 | 161 | assert(srcType && "Expected a memref type"); |
158 | | - assert(srcType.getRank() == 2 && "Expected a 2D memref"); |
159 | 162 |
|
160 | | - int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1]; |
| 163 | + auto shapeNd = srcType.getShape(); |
| 164 | + int64_t flatSize = |
| 165 | + std::accumulate(shapeNd.begin(), shapeNd.end(), 1, std::multiplies<>()); |
161 | 166 |
|
162 | 167 | Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
163 | 168 | Value size = rewriter.create<arith::ConstantIndexOp>(loc, flatSize); |
@@ -193,5 +198,128 @@ bool hasSharedMemSpace(mlir::Value memref) { |
193 | 198 | return false; |
194 | 199 | } |
195 | 200 |
|
| 201 | +std::tuple<SmallVector<Value>, Value> |
| 202 | +computeSubviewOffsets(PatternRewriter &rewriter, Location loc, Value memref) { |
| 203 | + auto fillVal = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| 204 | + auto origShape = dyn_cast<MemRefType>(memref.getType()).getShape(); |
| 205 | + |
| 206 | + SmallVector<Value> resolvedOffsets(origShape.size(), fillVal); |
| 207 | + |
| 208 | + while (auto subViewOp = memref.getDefiningOp<memref::SubViewOp>()) { |
| 209 | + auto currentOffsets = getAsOpFoldResult(resolvedOffsets); |
| 210 | + resolvedOffsets.clear(); |
| 211 | + |
| 212 | + affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
| 213 | + rewriter, memref.getLoc(), subViewOp.getMixedOffsets(), |
| 214 | + subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), currentOffsets, |
| 215 | + resolvedOffsets); |
| 216 | + memref = subViewOp.getOperand(0); |
| 217 | + } |
| 218 | + |
| 219 | + return std::make_tuple(resolvedOffsets, memref); |
| 220 | +} |
| 221 | + |
| 222 | +SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter, |
| 223 | + Location loc, Value memref) { |
| 224 | + auto type = dyn_cast<MemRefType>(memref.getType()); |
| 225 | + |
| 226 | + auto stridedLayout = dyn_cast<StridedLayoutAttr>(type.getLayout()); |
| 227 | + if (stridedLayout) { |
| 228 | + auto strides = stridedLayout.getStrides(); |
| 229 | + return getMixedValues(strides, {}, rewriter); |
| 230 | + } |
| 231 | + |
| 232 | + auto sizes = getMixedValues(type.getShape(), {}, rewriter); |
| 233 | + auto strides = memref::computeStridesIRBlock(loc, rewriter, sizes); |
| 234 | + return strides; |
| 235 | +} |
| 236 | + |
| 237 | +FailureOr<Value> squeezeMemref(PatternRewriter &rewriter, Location loc, |
| 238 | + Value memref, size_t maxDims = 2) { |
| 239 | + auto type = dyn_cast<MemRefType>(memref.getType()); |
| 240 | + auto shape = type.getShape(); |
| 241 | + |
| 242 | + if (shape.size() <= maxDims) |
| 243 | + return memref; |
| 244 | + |
| 245 | + for (size_t i = 0; i < shape.size() - maxDims; i++) |
| 246 | + if (shape[i] != 1) |
| 247 | + return failure(); |
| 248 | + |
| 249 | + auto offsets = |
| 250 | + getMixedValues(SmallVector<int64_t>(shape.size(), 0), {}, rewriter); |
| 251 | + auto sizes = getMixedValues(shape, {}, rewriter); |
| 252 | + auto staticStrides = utils::getStaticStrides(memref).value(); |
| 253 | + auto strides = |
| 254 | + getMixedValues(SmallVector<int64_t>(shape.size(), 1), {}, rewriter); |
| 255 | + |
| 256 | + SmallVector<int64_t> newShape(shape.begin() + shape.size() - maxDims, |
| 257 | + shape.end()); |
| 258 | + SmallVector<int64_t> newStrides( |
| 259 | + staticStrides.begin() + shape.size() - maxDims, staticStrides.end()); |
| 260 | + |
| 261 | + int64_t newOffset = 0; |
| 262 | + if (auto memrefLayout = dyn_cast<StridedLayoutAttr>(type.getLayout())) |
| 263 | + newOffset = memrefLayout.getOffset(); |
| 264 | + |
| 265 | + auto newLayout = StridedLayoutAttr::get( |
| 266 | + rewriter.getContext(), /*offset=*/newOffset, /*strides=*/newStrides); |
| 267 | + MemRefType newMemRefType = MemRefType::get(newShape, type.getElementType(), |
| 268 | + newLayout, type.getMemorySpace()); |
| 269 | + |
| 270 | + auto squeezedSubview = |
| 271 | + rewriter |
| 272 | + .create<memref::SubViewOp>(loc, newMemRefType, memref, offsets, sizes, |
| 273 | + strides) |
| 274 | + .getResult(); |
| 275 | + return squeezedSubview; |
| 276 | +} |
| 277 | + |
| 278 | +LogicalResult maybeSqueezeDims(PatternRewriter &rewriter, |
| 279 | + linalg::LinalgOp linalgOp, size_t maxDims) { |
| 280 | + SmallVector<std::pair<size_t, Value>> newOperands; |
| 281 | + auto operands = linalgOp->getOperands(); |
| 282 | + auto loc = linalgOp.getLoc(); |
| 283 | + |
| 284 | + for (size_t i = 0; i < operands.size(); i++) { |
| 285 | + auto operand = operands[i]; |
| 286 | + auto type = dyn_cast<MemRefType>(operand.getType()); |
| 287 | + if (!type) { |
| 288 | + // maybe should 'continue' here instead and skip non-memref operands? |
| 289 | + // TODO: replace this with 'continue' if such case would appear someday |
| 290 | + return rewriter.notifyMatchFailure( |
| 291 | + linalgOp, "Expect memref operand for XeGPU lowering"); |
| 292 | + } |
| 293 | + |
| 294 | + if (type.getShape().size() <= maxDims) |
| 295 | + continue; |
| 296 | + |
| 297 | + auto res = squeezeMemref(rewriter, loc, operand, maxDims); |
| 298 | + if (failed(res)) { |
| 299 | + return rewriter.notifyMatchFailure( |
| 300 | + linalgOp, "Can't squeeze memref to the desired number of dimensions"); |
| 301 | + } |
| 302 | + |
| 303 | + auto flatSubview = res.value(); |
| 304 | + newOperands.emplace_back(i, flatSubview); |
| 305 | + } |
| 306 | + |
| 307 | + for (auto [i, operand] : newOperands) |
| 308 | + linalgOp->setOperand(i, operand); |
| 309 | + |
| 310 | + return success(); |
| 311 | +} |
| 312 | + |
| 313 | +bool canSqueezeDims(llvm::ArrayRef<int64_t> shape, size_t maxDims) { |
| 314 | + if (shape.size() <= maxDims) |
| 315 | + return true; |
| 316 | + |
| 317 | + for (size_t i = 0; i < shape.size() - maxDims; i++) |
| 318 | + if (shape[i] != 1) |
| 319 | + return false; |
| 320 | + |
| 321 | + return true; |
| 322 | +} |
| 323 | + |
196 | 324 | } // namespace utils |
197 | 325 | } // namespace mlir |
0 commit comments