@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
10981098 op->erase ();
10991099}
11001100
1101- // Private helper function to transform memref.load with reduced rank.
1102- // This function will modify the indices of the memref.load to match the
1103- // newMemRef.
1104- LogicalResult transformMemRefLoadWithReducedRank (
1105- Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
1106- ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
1107- ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
1108- unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType ()).getRank ();
1109- unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType ()).getRank ();
1110- unsigned oldMapNumInputs = oldMemRefRank;
1111- SmallVector<Value, 4 > oldMapOperands (
1112- op->operand_begin () + memRefOperandPos + 1 ,
1113- op->operand_begin () + memRefOperandPos + 1 + oldMapNumInputs);
1114- SmallVector<Value, 4 > oldMemRefOperands;
1115- oldMemRefOperands.assign (oldMapOperands.begin (), oldMapOperands.end ());
1116- SmallVector<Value, 4 > remapOperands;
1117- remapOperands.reserve (extraOperands.size () + oldMemRefRank +
1118- symbolOperands.size ());
1119- remapOperands.append (extraOperands.begin (), extraOperands.end ());
1120- remapOperands.append (oldMemRefOperands.begin (), oldMemRefOperands.end ());
1121- remapOperands.append (symbolOperands.begin (), symbolOperands.end ());
1122-
1123- SmallVector<Value, 4 > remapOutputs;
1124- remapOutputs.reserve (oldMemRefRank);
1125- SmallVector<Value, 4 > affineApplyOps;
1126-
1127- OpBuilder builder (op);
1128-
1129- if (indexRemap &&
1130- indexRemap != builder.getMultiDimIdentityMap (indexRemap.getNumDims ())) {
1131- // Remapped indices.
1132- for (auto resultExpr : indexRemap.getResults ()) {
1133- auto singleResMap = AffineMap::get (
1134- indexRemap.getNumDims (), indexRemap.getNumSymbols (), resultExpr);
1135- auto afOp = builder.create <AffineApplyOp>(op->getLoc (), singleResMap,
1136- remapOperands);
1137- remapOutputs.push_back (afOp);
1138- affineApplyOps.push_back (afOp);
1139- }
1140- } else {
1141- // No remapping specified.
1142- remapOutputs.assign (remapOperands.begin (), remapOperands.end ());
1143- }
1144-
1145- SmallVector<Value, 4 > newMapOperands;
1146- newMapOperands.reserve (newMemRefRank);
1147-
1148- // Prepend 'extraIndices' in 'newMapOperands'.
1149- for (Value extraIndex : extraIndices) {
1150- assert ((isValidDim (extraIndex) || isValidSymbol (extraIndex)) &&
1151- " invalid memory op index" );
1152- newMapOperands.push_back (extraIndex);
1153- }
1154-
1155- // Append 'remapOutputs' to 'newMapOperands'.
1156- newMapOperands.append (remapOutputs.begin (), remapOutputs.end ());
1157-
1158- // Create new fully composed AffineMap for new op to be created.
1159- assert (newMapOperands.size () == newMemRefRank);
1160-
1161- OperationState state (op->getLoc (), op->getName ());
1162- // Construct the new operation using this memref.
1163- state.operands .reserve (newMapOperands.size () + extraIndices.size ());
1164- state.operands .push_back (newMemRef);
1165-
1166- // Insert the new memref map operands.
1167- state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1168-
1169- state.types .reserve (op->getNumResults ());
1170- for (auto result : op->getResults ())
1171- state.types .push_back (result.getType ());
1172-
1173- // Copy over the attributes from the old operation to the new operation.
1174- for (auto namedAttr : op->getAttrs ()) {
1175- state.attributes .push_back (namedAttr);
1176- }
1177-
1178- // Create the new operation.
1179- auto *repOp = builder.create (state);
1180- op->replaceAllUsesWith (repOp);
1181- op->erase ();
1182-
1183- return success ();
1101+ // Checks if `op` is non dereferencing.
1102+ // TODO: This hardcoded check will be removed once the right interface is added.
1103+ static bool isDereferencingOp (Operation *op) {
1104+ return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
11841105}
1106+
11851107// Perform the replacement in `op`.
11861108LogicalResult mlir::affine::replaceAllMemRefUsesWith (
11871109 Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,53 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
12161138 if (usePositions.empty ())
12171139 return success ();
12181140
1219- if (usePositions.size () > 1 ) {
1220- // TODO: extend it for this case when needed (rare).
1221- assert (false && " multiple dereferencing uses in a single op not supported" );
1222- return failure ();
1223- }
1224-
12251141 unsigned memRefOperandPos = usePositions.front ();
12261142
12271143 OpBuilder builder (op);
12281144 // The following checks if op is dereferencing memref and performs the access
12291145 // index rewrites.
1230- auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1231- if (!affMapAccInterface) {
1146+ if (!isDereferencingOp (op)) {
12321147 if (!allowNonDereferencingOps) {
12331148 // Failure: memref used in a non-dereferencing context (potentially
12341149 // escapes); no replacement in these cases unless allowNonDereferencingOps
12351150 // is set.
12361151 return failure ();
12371152 }
1153+ for (unsigned pos : usePositions)
1154+ op->setOperand (pos, newMemRef);
1155+ return success ();
1156+ }
12381157
1239- // Check if it is a memref.load
1240- auto memrefLoad = dyn_cast<memref::LoadOp>(op);
1241- bool isReductionLike =
1242- indexRemap.getNumResults () < indexRemap.getNumInputs ();
1243- if (!memrefLoad || !isReductionLike) {
1244- op->setOperand (memRefOperandPos, newMemRef);
1245- return success ();
1246- }
1158+ if (usePositions.size () > 1 ) {
1159+ // TODO: extend it for this case when needed (rare).
1160+ LLVM_DEBUG (llvm::dbgs ()
1161+ << " multiple dereferencing uses in a single op not supported" );
1162+ return failure ();
1163+ }
12471164
1248- return transformMemRefLoadWithReducedRank (
1249- op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
1250- symbolOperands, indexRemap);
1165+ // Perform index rewrites for the dereferencing op and then replace the op.
1166+ SmallVector<Value, 4 > oldMapOperands;
1167+ AffineMap oldMap;
1168+ unsigned oldMemRefNumIndices = oldMemRefRank;
1169+ auto startIdx = op->operand_begin () + memRefOperandPos + 1 ;
1170+ auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1171+ if (affMapAccInterface) {
1172+ // If `op` implements AffineMapAccessInterface, we can get the indices by
1173+ // quering the number of map operands from the operand list from a certain
1174+ // offset (`memRefOperandPos` in this case).
1175+ NamedAttribute oldMapAttrPair =
1176+ affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef);
1177+ oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue ()).getValue ();
1178+ oldMemRefNumIndices = oldMap.getNumInputs ();
12511179 }
1252- // Perform index rewrites for the dereferencing op and then replace the op
1253- NamedAttribute oldMapAttrPair =
1254- affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef);
1255- AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue ()).getValue ();
1256- unsigned oldMapNumInputs = oldMap.getNumInputs ();
1257- SmallVector<Value, 4 > oldMapOperands (
1258- op->operand_begin () + memRefOperandPos + 1 ,
1259- op->operand_begin () + memRefOperandPos + 1 + oldMapNumInputs);
1180+ oldMapOperands.assign (startIdx, startIdx + oldMemRefNumIndices);
12601181
12611182 // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
12621183 SmallVector<Value, 4 > oldMemRefOperands;
12631184 SmallVector<Value, 4 > affineApplyOps;
12641185 oldMemRefOperands.reserve (oldMemRefRank);
1265- if (oldMap != builder.getMultiDimIdentityMap (oldMap.getNumDims ())) {
1186+ if (affMapAccInterface &&
1187+ oldMap != builder.getMultiDimIdentityMap (oldMap.getNumDims ())) {
12661188 for (auto resultExpr : oldMap.getResults ()) {
12671189 auto singleResMap = AffineMap::get (oldMap.getNumDims (),
12681190 oldMap.getNumSymbols (), resultExpr);
@@ -1287,7 +1209,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
12871209
12881210 SmallVector<Value, 4 > remapOutputs;
12891211 remapOutputs.reserve (oldMemRefRank);
1290-
12911212 if (indexRemap &&
12921213 indexRemap != builder.getMultiDimIdentityMap (indexRemap.getNumDims ())) {
12931214 // Remapped indices.
@@ -1303,7 +1224,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13031224 // No remapping specified.
13041225 remapOutputs.assign (remapOperands.begin (), remapOperands.end ());
13051226 }
1306-
13071227 SmallVector<Value, 4 > newMapOperands;
13081228 newMapOperands.reserve (newMemRefRank);
13091229
@@ -1338,13 +1258,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13381258 state.operands .push_back (newMemRef);
13391259
13401260 // Insert the new memref map operands.
1341- state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1261+ if (affMapAccInterface) {
1262+ state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1263+ } else {
1264+ // In the case of dereferencing ops not implementing
1265+ // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
1266+ // to the `newMap` to get the correct indices.
1267+ for (unsigned i = 0 ; i < newMemRefRank; i++) {
1268+ state.operands .push_back (builder.create <AffineApplyOp>(
1269+ op->getLoc (),
1270+ AffineMap::get (newMap.getNumDims (), newMap.getNumSymbols (),
1271+ newMap.getResult (i)),
1272+ newMapOperands));
1273+ }
1274+ }
13421275
13431276 // Insert the remaining operands unmodified.
1277+ unsigned oldMapNumInputs = oldMapOperands.size ();
13441278 state.operands .append (op->operand_begin () + memRefOperandPos + 1 +
13451279 oldMapNumInputs,
13461280 op->operand_end ());
1347-
13481281 // Result types don't change. Both memref's are of the same elemental type.
13491282 state.types .reserve (op->getNumResults ());
13501283 for (auto result : op->getResults ())
@@ -1353,7 +1286,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13531286 // Add attribute for 'newMap', other Attributes do not change.
13541287 auto newMapAttr = AffineMapAttr::get (newMap);
13551288 for (auto namedAttr : op->getAttrs ()) {
1356- if (namedAttr.getName () == oldMapAttrPair.getName ())
1289+ if (affMapAccInterface &&
1290+ namedAttr.getName () ==
1291+ affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef).getName ())
13571292 state.attributes .push_back ({namedAttr.getName (), newMapAttr});
13581293 else
13591294 state.attributes .push_back (namedAttr);
@@ -1845,6 +1780,94 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
18451780 return success ();
18461781}
18471782
1783+ LogicalResult
1784+ mlir::affine::normalizeMemRef (memref::ReinterpretCastOp reinterpretCastOp) {
1785+ MemRefType memrefType = reinterpretCastOp.getType ();
1786+ AffineMap oldLayoutMap = memrefType.getLayout ().getAffineMap ();
1787+ Value oldMemRef = reinterpretCastOp.getResult ();
1788+
1789+ // If `oldLayoutMap` is identity, `memrefType` is already normalized.
1790+ if (oldLayoutMap.isIdentity ())
1791+ return success ();
1792+
1793+ // Fetch a new memref type after normalizing the old memref to have an
1794+ // identity map layout.
1795+ MemRefType newMemRefType = normalizeMemRefType (memrefType);
1796+ if (newMemRefType == memrefType)
1797+ // `oldLayoutMap` couldn't be transformed to an identity map.
1798+ return failure ();
1799+
1800+ uint64_t newRank = newMemRefType.getRank ();
1801+ SmallVector<Value> mapOperands (oldLayoutMap.getNumDims () +
1802+ oldLayoutMap.getNumSymbols ());
1803+ SmallVector<Value> oldStrides = reinterpretCastOp.getStrides ();
1804+ Location loc = reinterpretCastOp.getLoc ();
1805+ // As `newMemRefType` is normalized, it is unit strided.
1806+ SmallVector<int64_t > newStaticStrides (newRank, 1 );
1807+ SmallVector<int64_t > newStaticOffsets (newRank, 0 );
1808+ ArrayRef<int64_t > oldShape = memrefType.getShape ();
1809+ ValueRange oldSizes = reinterpretCastOp.getSizes ();
1810+ unsigned idx = 0 ;
1811+ SmallVector<int64_t > newStaticSizes;
1812+ OpBuilder b (reinterpretCastOp);
1813+ // Collect the map operands which will be used to compute the new normalized
1814+ // memref shape.
1815+ for (unsigned i = 0 , e = memrefType.getRank (); i < e; i++) {
1816+ if (memrefType.isDynamicDim (i))
1817+ mapOperands[i] =
1818+ b.create <arith::SubIOp>(loc, oldSizes[0 ].getType (), oldSizes[idx++],
1819+ b.create <arith::ConstantIndexOp>(loc, 1 ));
1820+ else
1821+ mapOperands[i] = b.create <arith::ConstantIndexOp>(loc, oldShape[i] - 1 );
1822+ }
1823+ for (unsigned i = 0 , e = oldStrides.size (); i < e; i++)
1824+ mapOperands[memrefType.getRank () + i] = oldStrides[i];
1825+ SmallVector<Value> newSizes;
1826+ ArrayRef<int64_t > newShape = newMemRefType.getShape ();
1827+ // Compute size along all the dimensions of the new normalized memref.
1828+ for (unsigned i = 0 ; i < newRank; i++) {
1829+ if (!newMemRefType.isDynamicDim (i))
1830+ continue ;
1831+ newSizes.push_back (b.create <AffineApplyOp>(
1832+ loc,
1833+ AffineMap::get (oldLayoutMap.getNumDims (), oldLayoutMap.getNumSymbols (),
1834+ oldLayoutMap.getResult (i)),
1835+ mapOperands));
1836+ }
1837+ for (unsigned i = 0 , e = newSizes.size (); i < e; i++) {
1838+ newSizes[i] =
1839+ b.create <arith::AddIOp>(loc, newSizes[i].getType (), newSizes[i],
1840+ b.create <arith::ConstantIndexOp>(loc, 1 ));
1841+ }
1842+ // Create the new reinterpret_cast op.
1843+ auto newReinterpretCast = b.create <memref::ReinterpretCastOp>(
1844+ loc, newMemRefType, reinterpretCastOp.getSource (),
1845+ /* offsets=*/ ValueRange (), newSizes,
1846+ /* strides=*/ ValueRange (),
1847+ /* static_offsets=*/ newStaticOffsets,
1848+ /* static_sizes=*/ newShape,
1849+ /* static_strides=*/ newStaticStrides);
1850+
1851+ // Replace all uses of the old memref.
1852+ if (failed (replaceAllMemRefUsesWith (oldMemRef,
1853+ /* newMemRef=*/ newReinterpretCast,
1854+ /* extraIndices=*/ {},
1855+ /* indexRemap=*/ oldLayoutMap,
1856+ /* extraOperands=*/ {},
1857+ /* symbolOperands=*/ oldStrides,
1858+ /* domOpFilter=*/ nullptr ,
1859+ /* postDomOpFilter=*/ nullptr ,
1860+ /* allowNonDereferencingOps=*/ true ))) {
1861+ // If it failed (due to escapes for example), bail out.
1862+ newReinterpretCast.erase ();
1863+ return failure ();
1864+ }
1865+
1866+ oldMemRef.replaceAllUsesWith (newReinterpretCast);
1867+ reinterpretCastOp.erase ();
1868+ return success ();
1869+ }
1870+
18481871template LogicalResult
18491872mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp op);
18501873template LogicalResult
0 commit comments