diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 898467d573362..910ad1733d03e 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>

 #define DEBUG_TYPE "affine-utils"
@@ -1093,6 +1094,90 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
     op->erase();
 }

+// Private helper function to transform memref.load with reduced rank.
+// This function modifies the indices of the memref.load to match the
+// newMemRef.
+LogicalResult transformMemRefLoadWithReducedRank(
+    Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
+    ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
+    ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
+  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
+  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
+  unsigned oldMapNumInputs = oldMemRefRank;
+  SmallVector<Value, 4> oldMapOperands(
+      op->operand_begin() + memRefOperandPos + 1,
+      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+  SmallVector<Value, 4> oldMemRefOperands;
+  oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
+  SmallVector<Value, 4> remapOperands;
+  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+                        symbolOperands.size());
+  remapOperands.append(extraOperands.begin(), extraOperands.end());
+  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+  remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+  SmallVector<Value, 4> remapOutputs;
+  remapOutputs.reserve(oldMemRefRank);
+  SmallVector<Value, 4> affineApplyOps;
+
+  OpBuilder builder(op);
+
+  if (indexRemap &&
+      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+    // Remapped indices.
+    for (auto resultExpr : indexRemap.getResults()) {
+      auto singleResMap = AffineMap::get(
+          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                remapOperands);
+      remapOutputs.push_back(afOp);
+      affineApplyOps.push_back(afOp);
+    }
+  } else {
+    // No remapping specified.
+    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
+  }
+
+  SmallVector<Value, 4> newMapOperands;
+  newMapOperands.reserve(newMemRefRank);
+
+  // Prepend 'extraIndices' in 'newMapOperands'.
+  for (Value extraIndex : extraIndices) {
+    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+           "invalid memory op index");
+    newMapOperands.push_back(extraIndex);
+  }
+
+  // Append 'remapOutputs' to 'newMapOperands'.
+  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+  // Create new fully composed AffineMap for new op to be created.
+  assert(newMapOperands.size() == newMemRefRank);
+
+  OperationState state(op->getLoc(), op->getName());
+  // Construct the new operation using this memref.
+  state.operands.reserve(newMapOperands.size() + extraIndices.size());
+  state.operands.push_back(newMemRef);
+
+  // Insert the new memref map operands.
+  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+  state.types.reserve(op->getNumResults());
+  for (auto result : op->getResults())
+    state.types.push_back(result.getType());
+
+  // Copy over the attributes from the old operation to the new operation.
+  for (auto namedAttr : op->getAttrs()) {
+    state.attributes.push_back(namedAttr);
+  }
+
+  // Create the new operation.
+  auto *repOp = builder.create(state);
+  op->replaceAllUsesWith(repOp);
+  op->erase();
+
+  return success();
+}
 // Perform the replacement in `op`.
 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     Value oldMemRef, Value newMemRef, Operation *op,
@@ -1146,8 +1231,19 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
       // is set.
       return failure();
     }
-    op->setOperand(memRefOperandPos, newMemRef);
-    return success();
+
+    // Check if it is a memref.load.
+    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
+    bool isReductionLike =
+        indexRemap.getNumResults() < indexRemap.getNumInputs();
+    if (!memrefLoad || !isReductionLike) {
+      op->setOperand(memRefOperandPos, newMemRef);
+      return success();
+    }
+
+    return transformMemRefLoadWithReducedRank(
+        op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
+        symbolOperands, indexRemap);
   }
   // Perform index rewrites for the dereferencing op and then replace the op
   NamedAttribute oldMapAttrPair =
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index c7af033a22a2c..11114bcf2b1ab 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -3,6 +3,10 @@
 // This file tests whether the memref type having non-trivial map layouts
 // are normalized to trivial (identity) layouts.

+// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
+// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
+// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
+
 // CHECK-LABEL: func @permute()
 func.func @permute() {
   %A = memref.alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
@@ -363,3 +367,37 @@ func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index,
   %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
   return %1 : tensor<16x512xf32>
 }
+
+#map0 = affine_map<(i,k) -> (2 * (i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8 * (k floordiv 2))>
+#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4 * (j floordiv 2))>
+#map2 = affine_map<(i,j) -> (4 * i + j)>
+// CHECK-LABEL: func @memref_load_with_reduction_map
+func.func @memref_load_with_reduction_map(%arg0 : memref<4x4xf32,#map2>) -> () {
+  %0 = memref.alloc() : memref<4x8xf32,#map0>
+  %1 = memref.alloc() : memref<8x4xf32,#map1>
+  %2 = memref.alloc() : memref<4x4xf32,#map2>
+  // CHECK-NOT: memref<4x8xf32>
+  // CHECK-NOT: memref<8x4xf32>
+  // CHECK-NOT: memref<4x4xf32>
+  %cst = arith.constant 3.0 : f32
+  %cst0 = arith.constant 0 : index
+  affine.for %i = 0 to 4 {
+    affine.for %j = 0 to 8 {
+      affine.for %k = 0 to 8 {
+        // CHECK: %[[INDEX0:.*]] = affine.apply #[[$REDUCE_MAP1]](%{{.*}}, %{{.*}})
+        // CHECK: memref.load %alloc[%[[INDEX0]]] : memref<32xf32>
+        %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
+        // CHECK: %[[INDEX1:.*]] = affine.apply #[[$REDUCE_MAP2]](%{{.*}}, %{{.*}})
+        // CHECK: memref.load %alloc_0[%[[INDEX1]]] : memref<32xf32>
+        %b = memref.load %1[%k, %j] : memref<8x4xf32,#map1>
+        // CHECK: %[[INDEX2:.*]] = affine.apply #[[$REDUCE_MAP3]](%{{.*}}, %{{.*}})
+        // CHECK: memref.load %alloc_1[%[[INDEX2]]] : memref<16xf32>
+        %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
+        %3 = arith.mulf %a, %b : f32
+        %4 = arith.addf %3, %c : f32
+        affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
+      }
+    }
+  }
+  return
+}
\ No newline at end of file