@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
10981098 op->erase ();
10991099}
11001100
1101- // Private helper function to transform memref.load with reduced rank.
1102- // This function will modify the indices of the memref.load to match the
1103- // newMemRef.
1104- LogicalResult transformMemRefLoadWithReducedRank (
1105- Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
1106- ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
1107- ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
1108- unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType ()).getRank ();
1109- unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType ()).getRank ();
1110- unsigned oldMapNumInputs = oldMemRefRank;
1111- SmallVector<Value, 4 > oldMapOperands (
1112- op->operand_begin () + memRefOperandPos + 1 ,
1113- op->operand_begin () + memRefOperandPos + 1 + oldMapNumInputs);
1114- SmallVector<Value, 4 > oldMemRefOperands;
1115- oldMemRefOperands.assign (oldMapOperands.begin (), oldMapOperands.end ());
1116- SmallVector<Value, 4 > remapOperands;
1117- remapOperands.reserve (extraOperands.size () + oldMemRefRank +
1118- symbolOperands.size ());
1119- remapOperands.append (extraOperands.begin (), extraOperands.end ());
1120- remapOperands.append (oldMemRefOperands.begin (), oldMemRefOperands.end ());
1121- remapOperands.append (symbolOperands.begin (), symbolOperands.end ());
1122-
1123- SmallVector<Value, 4 > remapOutputs;
1124- remapOutputs.reserve (oldMemRefRank);
1125- SmallVector<Value, 4 > affineApplyOps;
1126-
1127- OpBuilder builder (op);
1128-
1129- if (indexRemap &&
1130- indexRemap != builder.getMultiDimIdentityMap (indexRemap.getNumDims ())) {
1131- // Remapped indices.
1132- for (auto resultExpr : indexRemap.getResults ()) {
1133- auto singleResMap = AffineMap::get (
1134- indexRemap.getNumDims (), indexRemap.getNumSymbols (), resultExpr);
1135- auto afOp = builder.create <AffineApplyOp>(op->getLoc (), singleResMap,
1136- remapOperands);
1137- remapOutputs.push_back (afOp);
1138- affineApplyOps.push_back (afOp);
1139- }
1140- } else {
1141- // No remapping specified.
1142- remapOutputs.assign (remapOperands.begin (), remapOperands.end ());
1143- }
1144-
1145- SmallVector<Value, 4 > newMapOperands;
1146- newMapOperands.reserve (newMemRefRank);
1147-
1148- // Prepend 'extraIndices' in 'newMapOperands'.
1149- for (Value extraIndex : extraIndices) {
1150- assert ((isValidDim (extraIndex) || isValidSymbol (extraIndex)) &&
1151- " invalid memory op index" );
1152- newMapOperands.push_back (extraIndex);
1153- }
1154-
1155- // Append 'remapOutputs' to 'newMapOperands'.
1156- newMapOperands.append (remapOutputs.begin (), remapOutputs.end ());
1157-
1158- // Create new fully composed AffineMap for new op to be created.
1159- assert (newMapOperands.size () == newMemRefRank);
1160-
1161- OperationState state (op->getLoc (), op->getName ());
1162- // Construct the new operation using this memref.
1163- state.operands .reserve (newMapOperands.size () + extraIndices.size ());
1164- state.operands .push_back (newMemRef);
1165-
1166- // Insert the new memref map operands.
1167- state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1168-
1169- state.types .reserve (op->getNumResults ());
1170- for (auto result : op->getResults ())
1171- state.types .push_back (result.getType ());
1172-
1173- // Copy over the attributes from the old operation to the new operation.
1174- for (auto namedAttr : op->getAttrs ()) {
1175- state.attributes .push_back (namedAttr);
1176- }
1177-
1178- // Create the new operation.
1179- auto *repOp = builder.create (state);
1180- op->replaceAllUsesWith (repOp);
1181- op->erase ();
1182-
1183- return success ();
1101+ // Checks if `op` is a dereferencing op, i.e., one that accesses the memref.
1102+ // TODO: This hardcoded check will be removed once the right interface is added.
1103+ static bool isDereferencingOp (Operation *op) {
1104+ return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
11841105}
1106+
11851107// Perform the replacement in `op`.
11861108LogicalResult mlir::affine::replaceAllMemRefUsesWith (
11871109 Value oldMemRef, Value newMemRef, Operation *op,
@@ -1228,41 +1150,44 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
12281150 // The following checks if op is dereferencing memref and performs the access
12291151 // index rewrites.
12301152 auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1231- if (!affMapAccInterface ) {
1153+ if (!isDereferencingOp (op) ) {
12321154 if (!allowNonDereferencingOps) {
12331155 // Failure: memref used in a non-dereferencing context (potentially
12341156 // escapes); no replacement in these cases unless allowNonDereferencingOps
12351157 // is set.
12361158 return failure ();
12371159 }
1160+ op->setOperand (memRefOperandPos, newMemRef);
1161+ return success ();
1162+ }
12381163
1239- // Check if it is a memref.load
1240- auto memrefLoad = dyn_cast<memref::LoadOp>(op);
1241- bool isReductionLike =
1242- indexRemap.getNumResults () < indexRemap.getNumInputs ();
1243- if (!memrefLoad || !isReductionLike) {
1244- op->setOperand (memRefOperandPos, newMemRef);
1245- return success ();
1246- }
1247-
1248- return transformMemRefLoadWithReducedRank (
1249- op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
1250- symbolOperands, indexRemap);
1164+ // Perform index rewrites for the dereferencing op and then replace the op.
1165+ SmallVector<Value, 4 > oldMapOperands;
1166+ AffineMap oldMap;
1167+ unsigned oldMemRefNumIndices = oldMemRefRank;
1168+ if (affMapAccInterface) {
1169+ // If `op` implements AffineMapAccessInterface, we can get the indices by
1170+ // querying the number of map operands from the operand list from a certain
1171+ // offset (`memRefOperandPos` in this case).
1172+ NamedAttribute oldMapAttrPair =
1173+ affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef);
1174+ oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue ()).getValue ();
1175+ oldMemRefNumIndices = oldMap.getNumInputs ();
1176+ oldMapOperands.assign (op->operand_begin () + memRefOperandPos + 1 ,
1177+ op->operand_begin () + memRefOperandPos + 1 +
1178+ oldMemRefNumIndices);
1179+ } else {
1180+ oldMapOperands.assign (op->operand_begin () + memRefOperandPos + 1 ,
1181+ op->operand_begin () + memRefOperandPos + 1 +
1182+ oldMemRefRank);
12511183 }
1252- // Perform index rewrites for the dereferencing op and then replace the op
1253- NamedAttribute oldMapAttrPair =
1254- affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef);
1255- AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue ()).getValue ();
1256- unsigned oldMapNumInputs = oldMap.getNumInputs ();
1257- SmallVector<Value, 4 > oldMapOperands (
1258- op->operand_begin () + memRefOperandPos + 1 ,
1259- op->operand_begin () + memRefOperandPos + 1 + oldMapNumInputs);
12601184
12611185 // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
12621186 SmallVector<Value, 4 > oldMemRefOperands;
12631187 SmallVector<Value, 4 > affineApplyOps;
12641188 oldMemRefOperands.reserve (oldMemRefRank);
1265- if (oldMap != builder.getMultiDimIdentityMap (oldMap.getNumDims ())) {
1189+ if (affMapAccInterface &&
1190+ oldMap != builder.getMultiDimIdentityMap (oldMap.getNumDims ())) {
12661191 for (auto resultExpr : oldMap.getResults ()) {
12671192 auto singleResMap = AffineMap::get (oldMap.getNumDims (),
12681193 oldMap.getNumSymbols (), resultExpr);
@@ -1287,7 +1212,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
12871212
12881213 SmallVector<Value, 4 > remapOutputs;
12891214 remapOutputs.reserve (oldMemRefRank);
1290-
12911215 if (indexRemap &&
12921216 indexRemap != builder.getMultiDimIdentityMap (indexRemap.getNumDims ())) {
12931217 // Remapped indices.
@@ -1303,7 +1227,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13031227 // No remapping specified.
13041228 remapOutputs.assign (remapOperands.begin (), remapOperands.end ());
13051229 }
1306-
13071230 SmallVector<Value, 4 > newMapOperands;
13081231 newMapOperands.reserve (newMemRefRank);
13091232
@@ -1338,13 +1261,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13381261 state.operands .push_back (newMemRef);
13391262
13401263 // Insert the new memref map operands.
1341- state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1264+ if (affMapAccInterface) {
1265+ state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1266+ } else {
1267+ // In the case of dereferencing ops not implementing
1268+ // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
1269+ // to the `newMap` to get the correct indices.
1270+ for (unsigned i = 0 ; i < newMemRefRank; i++)
1271+ state.operands .push_back (builder.create <AffineApplyOp>(
1272+ op->getLoc (),
1273+ AffineMap::get (newMap.getNumDims (), newMap.getNumSymbols (),
1274+ newMap.getResult (i)),
1275+ newMapOperands));
1276+ }
13421277
13431278 // Insert the remaining operands unmodified.
1279+ unsigned oldMapNumInputs = oldMapOperands.size ();
1280+
13441281 state.operands .append (op->operand_begin () + memRefOperandPos + 1 +
13451282 oldMapNumInputs,
13461283 op->operand_end ());
1347-
13481284 // Result types don't change. Both memref's are of the same elemental type.
13491285 state.types .reserve (op->getNumResults ());
13501286 for (auto result : op->getResults ())
@@ -1353,7 +1289,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13531289 // Add attribute for 'newMap', other Attributes do not change.
13541290 auto newMapAttr = AffineMapAttr::get (newMap);
13551291 for (auto namedAttr : op->getAttrs ()) {
1356- if (namedAttr.getName () == oldMapAttrPair.getName ())
1292+ if (affMapAccInterface &&
1293+ namedAttr.getName () ==
1294+ affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef).getName ())
13571295 state.attributes .push_back ({namedAttr.getName (), newMapAttr});
13581296 else
13591297 state.attributes .push_back (namedAttr);
@@ -1846,6 +1784,92 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
18461784 return success ();
18471785}
18481786
1787+ LogicalResult
1788+ mlir::affine::normalizeMemRef (memref::ReinterpretCastOp *reinterpretCastOp) {
1789+ MemRefType memrefType = reinterpretCastOp->getType ();
1790+ AffineMap oldLayoutMap = memrefType.getLayout ().getAffineMap ();
1791+ Value oldMemRef = reinterpretCastOp->getResult ();
1792+
1793+ // If `oldLayoutMap` is identity, `memrefType` is already normalized.
1794+ if (oldLayoutMap.isIdentity ())
1795+ return success ();
1796+
1797+ // Fetch a new memref type after normalizing the old memref to have an
1798+ // identity map layout.
1799+ MemRefType newMemRefType = normalizeMemRefType (memrefType);
1800+ if (newMemRefType == memrefType)
1801+ // `oldLayoutMap` couldn't be transformed to an identity map.
1802+ return failure ();
1803+
1804+ uint64_t newRank = newMemRefType.getRank ();
1805+ SmallVector<Value> mapOperands (oldLayoutMap.getNumDims () +
1806+ oldLayoutMap.getNumSymbols ());
1807+ SmallVector<Value> oldStrides = reinterpretCastOp->getStrides ();
1808+ Location loc = reinterpretCastOp->getLoc ();
1809+ // As `newMemRefType` is normalized, it is unit strided.
1810+ SmallVector<int64_t > newStaticStrides (newRank, 1 );
1811+ ArrayRef<int64_t > oldShape = memrefType.getShape ();
1812+ mlir::ValueRange oldSizes = reinterpretCastOp->getSizes ();
1813+ unsigned idx = 0 ;
1814+ SmallVector<int64_t > newStaticSizes;
1815+ OpBuilder b (*reinterpretCastOp);
1816+ // Collect the map operands which will be used to compute the new normalized
1817+ // memref shape.
1818+ for (unsigned i = 0 , e = memrefType.getRank (); i < e; i++) {
1819+ if (oldShape[i] == ShapedType::kDynamic )
1820+ mapOperands[i] =
1821+ b.create <arith::SubIOp>(loc, oldSizes[0 ].getType (), oldSizes[idx++],
1822+ b.create <arith::ConstantIndexOp>(loc, 1 ));
1823+ else
1824+ mapOperands[i] = b.create <arith::ConstantIndexOp>(loc, oldShape[i] - 1 );
1825+ }
1826+ for (unsigned i = 0 , e = oldStrides.size (); i < e; i++)
1827+ mapOperands[memrefType.getRank () + i] = oldStrides[i];
1828+ SmallVector<Value> newSizes;
1829+ ArrayRef<int64_t > newShape = newMemRefType.getShape ();
1830+ // Compute size along all the dimensions of the new normalized memref.
1831+ for (unsigned i = 0 ; i < newRank; i++) {
1832+ if (newShape[i] != ShapedType::kDynamic )
1833+ continue ;
1834+ newSizes.push_back (b.create <AffineApplyOp>(
1835+ loc,
1836+ AffineMap::get (oldLayoutMap.getNumDims (), oldLayoutMap.getNumSymbols (),
1837+ oldLayoutMap.getResult (i)),
1838+ mapOperands));
1839+ }
1840+ for (unsigned i = 0 , e = newSizes.size (); i < e; i++)
1841+ newSizes[i] =
1842+ b.create <arith::AddIOp>(loc, newSizes[i].getType (), newSizes[i],
1843+ b.create <arith::ConstantIndexOp>(loc, 1 ));
1844+ // Create the new reinterpret_cast op.
1845+ memref::ReinterpretCastOp newReinterpretCast =
1846+ b.create <memref::ReinterpretCastOp>(
1847+ loc, newMemRefType, reinterpretCastOp->getSource (),
1848+ reinterpretCastOp->getOffsets (), newSizes, mlir::ValueRange (),
1849+ /* static_offsets=*/ reinterpretCastOp->getStaticOffsets (),
1850+ /* static_sizes=*/ newShape,
1851+ /* static_strides=*/ newStaticStrides);
1852+
1853+ // Replace all uses of the old memref.
1854+ if (failed (replaceAllMemRefUsesWith (oldMemRef,
1855+ /* newMemRef=*/ newReinterpretCast,
1856+ /* extraIndices=*/ {},
1857+ /* indexRemap=*/ oldLayoutMap,
1858+ /* extraOperands=*/ {},
1859+ /* symbolOperands=*/ oldStrides,
1860+ /* domOpFilter=*/ nullptr ,
1861+ /* postDomOpFilter=*/ nullptr ,
1862+ /* allowNonDereferencingOps=*/ true ))) {
1863+ // If it failed (due to escapes for example), bail out.
1864+ newReinterpretCast->erase ();
1865+ return failure ();
1866+ }
1867+
1868+ oldMemRef.replaceAllUsesWith (newReinterpretCast);
1869+ reinterpretCastOp->erase ();
1870+ return success ();
1871+ }
1872+
18491873template LogicalResult
18501874mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
18511875template LogicalResult
0 commit comments