@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
10981098 op->erase ();
10991099}
11001100
1101- // Private helper function to transform memref.load with reduced rank.
1102- // This function will modify the indices of the memref.load to match the
1103- // newMemRef.
1104- LogicalResult transformMemRefLoadWithReducedRank (
1105- Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
1106- ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
1107- ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
1108- unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType ()).getRank ();
1109- unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType ()).getRank ();
1110- unsigned oldMapNumInputs = oldMemRefRank;
1111- SmallVector<Value, 4 > oldMapOperands (
1112- op->operand_begin () + memRefOperandPos + 1 ,
1113- op->operand_begin () + memRefOperandPos + 1 + oldMapNumInputs);
1114- SmallVector<Value, 4 > oldMemRefOperands;
1115- oldMemRefOperands.assign (oldMapOperands.begin (), oldMapOperands.end ());
1116- SmallVector<Value, 4 > remapOperands;
1117- remapOperands.reserve (extraOperands.size () + oldMemRefRank +
1118- symbolOperands.size ());
1119- remapOperands.append (extraOperands.begin (), extraOperands.end ());
1120- remapOperands.append (oldMemRefOperands.begin (), oldMemRefOperands.end ());
1121- remapOperands.append (symbolOperands.begin (), symbolOperands.end ());
1122-
1123- SmallVector<Value, 4 > remapOutputs;
1124- remapOutputs.reserve (oldMemRefRank);
1125- SmallVector<Value, 4 > affineApplyOps;
1126-
1127- OpBuilder builder (op);
1128-
1129- if (indexRemap &&
1130- indexRemap != builder.getMultiDimIdentityMap (indexRemap.getNumDims ())) {
1131- // Remapped indices.
1132- for (auto resultExpr : indexRemap.getResults ()) {
1133- auto singleResMap = AffineMap::get (
1134- indexRemap.getNumDims (), indexRemap.getNumSymbols (), resultExpr);
1135- auto afOp = builder.create <AffineApplyOp>(op->getLoc (), singleResMap,
1136- remapOperands);
1137- remapOutputs.push_back (afOp);
1138- affineApplyOps.push_back (afOp);
1139- }
1140- } else {
1141- // No remapping specified.
1142- remapOutputs.assign (remapOperands.begin (), remapOperands.end ());
1143- }
1144-
1145- SmallVector<Value, 4 > newMapOperands;
1146- newMapOperands.reserve (newMemRefRank);
1147-
1148- // Prepend 'extraIndices' in 'newMapOperands'.
1149- for (Value extraIndex : extraIndices) {
1150- assert ((isValidDim (extraIndex) || isValidSymbol (extraIndex)) &&
1151- " invalid memory op index" );
1152- newMapOperands.push_back (extraIndex);
1153- }
1154-
1155- // Append 'remapOutputs' to 'newMapOperands'.
1156- newMapOperands.append (remapOutputs.begin (), remapOutputs.end ());
1157-
1158- // Create new fully composed AffineMap for new op to be created.
1159- assert (newMapOperands.size () == newMemRefRank);
1160-
1161- OperationState state (op->getLoc (), op->getName ());
1162- // Construct the new operation using this memref.
1163- state.operands .reserve (newMapOperands.size () + extraIndices.size ());
1164- state.operands .push_back (newMemRef);
1165-
1166- // Insert the new memref map operands.
1167- state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1168-
1169- state.types .reserve (op->getNumResults ());
1170- for (auto result : op->getResults ())
1171- state.types .push_back (result.getType ());
1172-
1173- // Copy over the attributes from the old operation to the new operation.
1174- for (auto namedAttr : op->getAttrs ()) {
1175- state.attributes .push_back (namedAttr);
1176- }
1177-
1178- // Create the new operation.
1179- auto *repOp = builder.create (state);
1180- op->replaceAllUsesWith (repOp);
1181- op->erase ();
1182-
1183- return success ();
1101+ // Checks if `op` is non dereferencing.
1102+ // TODO: This hardcoded check will be removed once the right interface is added.
1103+ static bool isDereferencingOp (Operation *op) {
1104+ return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
11841105}
1106+
11851107// Perform the replacement in `op`.
11861108LogicalResult mlir::affine::replaceAllMemRefUsesWith (
11871109 Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,57 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
12161138 if (usePositions.empty ())
12171139 return success ();
12181140
1219- if (usePositions.size () > 1 ) {
1220- // TODO: extend it for this case when needed (rare).
1221- assert (false && " multiple dereferencing uses in a single op not supported" );
1222- return failure ();
1223- }
1224-
12251141 unsigned memRefOperandPos = usePositions.front ();
12261142
12271143 OpBuilder builder (op);
12281144 // The following checks if op is dereferencing memref and performs the access
12291145 // index rewrites.
12301146 auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1231- if (!affMapAccInterface ) {
1147+ if (!isDereferencingOp (op) ) {
12321148 if (!allowNonDereferencingOps) {
12331149 // Failure: memref used in a non-dereferencing context (potentially
12341150 // escapes); no replacement in these cases unless allowNonDereferencingOps
12351151 // is set.
12361152 return failure ();
12371153 }
1154+ for (unsigned pos : usePositions)
1155+ op->setOperand (pos, newMemRef);
1156+ return success ();
1157+ }
12381158
1239- // Check if it is a memref.load
1240- auto memrefLoad = dyn_cast<memref::LoadOp>(op);
1241- bool isReductionLike =
1242- indexRemap.getNumResults () < indexRemap.getNumInputs ();
1243- if (!memrefLoad || !isReductionLike) {
1244- op->setOperand (memRefOperandPos, newMemRef);
1245- return success ();
1246- }
1159+ if (usePositions.size () > 1 ) {
1160+ // TODO: extend it for this case when needed (rare).
1161+ assert (false && " multiple dereferencing uses in a single op not supported" );
1162+ return failure ();
1163+ }
12471164
1248- return transformMemRefLoadWithReducedRank (
1249- op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
1250- symbolOperands, indexRemap);
1165+ // Perform index rewrites for the dereferencing op and then replace the op.
1166+ SmallVector<Value, 4 > oldMapOperands;
1167+ AffineMap oldMap;
1168+ unsigned oldMemRefNumIndices = oldMemRefRank;
1169+ if (affMapAccInterface) {
1170+  // If `op` implements AffineMapAccessInterface, we can get the indices by
1171+  // querying the number of map operands from the operand list at a certain
1172+  // offset (`memRefOperandPos` in this case).
1173+ NamedAttribute oldMapAttrPair =
1174+ affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef);
1175+ oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue ()).getValue ();
1176+ oldMemRefNumIndices = oldMap.getNumInputs ();
1177+ oldMapOperands.assign (op->operand_begin () + memRefOperandPos + 1 ,
1178+ op->operand_begin () + memRefOperandPos + 1 +
1179+ oldMemRefNumIndices);
1180+ } else {
1181+ oldMapOperands.assign (op->operand_begin () + memRefOperandPos + 1 ,
1182+ op->operand_begin () + memRefOperandPos + 1 +
1183+ oldMemRefRank);
12511184 }
1252- // Perform index rewrites for the dereferencing op and then replace the op
1253- NamedAttribute oldMapAttrPair =
1254- affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef);
1255- AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue ()).getValue ();
1256- unsigned oldMapNumInputs = oldMap.getNumInputs ();
1257- SmallVector<Value, 4 > oldMapOperands (
1258- op->operand_begin () + memRefOperandPos + 1 ,
1259- op->operand_begin () + memRefOperandPos + 1 + oldMapNumInputs);
12601185
12611186 // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
12621187 SmallVector<Value, 4 > oldMemRefOperands;
12631188 SmallVector<Value, 4 > affineApplyOps;
12641189 oldMemRefOperands.reserve (oldMemRefRank);
1265- if (oldMap != builder.getMultiDimIdentityMap (oldMap.getNumDims ())) {
1190+ if (affMapAccInterface &&
1191+ oldMap != builder.getMultiDimIdentityMap (oldMap.getNumDims ())) {
12661192 for (auto resultExpr : oldMap.getResults ()) {
12671193 auto singleResMap = AffineMap::get (oldMap.getNumDims (),
12681194 oldMap.getNumSymbols (), resultExpr);
@@ -1287,7 +1213,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
12871213
12881214 SmallVector<Value, 4 > remapOutputs;
12891215 remapOutputs.reserve (oldMemRefRank);
1290-
12911216 if (indexRemap &&
12921217 indexRemap != builder.getMultiDimIdentityMap (indexRemap.getNumDims ())) {
12931218 // Remapped indices.
@@ -1303,7 +1228,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13031228 // No remapping specified.
13041229 remapOutputs.assign (remapOperands.begin (), remapOperands.end ());
13051230 }
1306-
13071231 SmallVector<Value, 4 > newMapOperands;
13081232 newMapOperands.reserve (newMemRefRank);
13091233
@@ -1338,13 +1262,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13381262 state.operands .push_back (newMemRef);
13391263
13401264 // Insert the new memref map operands.
1341- state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1265+ if (affMapAccInterface) {
1266+ state.operands .append (newMapOperands.begin (), newMapOperands.end ());
1267+ } else {
1268+ // In the case of dereferencing ops not implementing
1269+ // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
1270+ // to the `newMap` to get the correct indices.
1271+ for (unsigned i = 0 ; i < newMemRefRank; i++)
1272+ state.operands .push_back (builder.create <AffineApplyOp>(
1273+ op->getLoc (),
1274+ AffineMap::get (newMap.getNumDims (), newMap.getNumSymbols (),
1275+ newMap.getResult (i)),
1276+ newMapOperands));
1277+ }
13421278
13431279 // Insert the remaining operands unmodified.
1280+ unsigned oldMapNumInputs = oldMapOperands.size ();
1281+
13441282 state.operands .append (op->operand_begin () + memRefOperandPos + 1 +
13451283 oldMapNumInputs,
13461284 op->operand_end ());
1347-
13481285 // Result types don't change. Both memref's are of the same elemental type.
13491286 state.types .reserve (op->getNumResults ());
13501287 for (auto result : op->getResults ())
@@ -1353,7 +1290,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13531290 // Add attribute for 'newMap', other Attributes do not change.
13541291 auto newMapAttr = AffineMapAttr::get (newMap);
13551292 for (auto namedAttr : op->getAttrs ()) {
1356- if (namedAttr.getName () == oldMapAttrPair.getName ())
1293+ if (affMapAccInterface &&
1294+ namedAttr.getName () ==
1295+ affMapAccInterface.getAffineMapAttrForMemRef (oldMemRef).getName ())
13571296 state.attributes .push_back ({namedAttr.getName (), newMapAttr});
13581297 else
13591298 state.attributes .push_back (namedAttr);
@@ -1846,6 +1785,95 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
18461785 return success ();
18471786}
18481787
1788+ LogicalResult
1789+ mlir::affine::normalizeMemRef (memref::ReinterpretCastOp *reinterpretCastOp) {
1790+ MemRefType memrefType = reinterpretCastOp->getType ();
1791+ AffineMap oldLayoutMap = memrefType.getLayout ().getAffineMap ();
1792+ Value oldMemRef = reinterpretCastOp->getResult ();
1793+
1794+ // If `oldLayoutMap` is identity, `memrefType` is already normalized.
1795+ if (oldLayoutMap.isIdentity ())
1796+ return success ();
1797+
1798+ // Fetch a new memref type after normalizing the old memref to have an
1799+ // identity map layout.
1800+ MemRefType newMemRefType = normalizeMemRefType (memrefType);
1801+ newMemRefType.dump ();
1802+ if (newMemRefType == memrefType)
1803+ // `oldLayoutMap` couldn't be transformed to an identity map.
1804+ return failure ();
1805+
1806+ uint64_t newRank = newMemRefType.getRank ();
1807+ SmallVector<Value> mapOperands (oldLayoutMap.getNumDims () +
1808+ oldLayoutMap.getNumSymbols ());
1809+ SmallVector<Value> oldStrides = reinterpretCastOp->getStrides ();
1810+ Location loc = reinterpretCastOp->getLoc ();
1811+ // As `newMemRefType` is normalized, it is unit strided.
1812+ SmallVector<int64_t > newStaticStrides (newRank, 1 );
1813+ SmallVector<int64_t > newStaticOffsets (newRank, 0 );
1814+ ArrayRef<int64_t > oldShape = memrefType.getShape ();
1815+ mlir::ValueRange oldSizes = reinterpretCastOp->getSizes ();
1816+ unsigned idx = 0 ;
1817+ SmallVector<int64_t > newStaticSizes;
1818+ OpBuilder b (*reinterpretCastOp);
1819+ // Collectthe map operands which will be used to compute the new normalized
1820+ // memref shape.
1821+ for (unsigned i = 0 , e = memrefType.getRank (); i < e; i++) {
1822+ if (oldShape[i] == ShapedType::kDynamic )
1823+ mapOperands[i] =
1824+ b.create <arith::SubIOp>(loc, oldSizes[0 ].getType (), oldSizes[idx++],
1825+ b.create <arith::ConstantIndexOp>(loc, 1 ));
1826+ else
1827+ mapOperands[i] = b.create <arith::ConstantIndexOp>(loc, oldShape[i] - 1 );
1828+ }
1829+ for (unsigned i = 0 , e = oldStrides.size (); i < e; i++)
1830+ mapOperands[memrefType.getRank () + i] = oldStrides[i];
1831+ SmallVector<Value> newSizes;
1832+ ArrayRef<int64_t > newShape = newMemRefType.getShape ();
1833+ // Compute size along all the dimensions of the new normalized memref.
1834+ for (unsigned i = 0 ; i < newRank; i++) {
1835+ if (newMemRefType.isDynamicDim (i))
1836+ continue ;
1837+ newSizes.push_back (b.create <AffineApplyOp>(
1838+ loc,
1839+ AffineMap::get (oldLayoutMap.getNumDims (), oldLayoutMap.getNumSymbols (),
1840+ oldLayoutMap.getResult (i)),
1841+ mapOperands));
1842+ }
1843+ for (unsigned i = 0 , e = newSizes.size (); i < e; i++)
1844+ newSizes[i] =
1845+ b.create <arith::AddIOp>(loc, newSizes[i].getType (), newSizes[i],
1846+ b.create <arith::ConstantIndexOp>(loc, 1 ));
1847+ // Create the new reinterpret_cast op.
1848+ memref::ReinterpretCastOp newReinterpretCast =
1849+ b.create <memref::ReinterpretCastOp>(
1850+ loc, newMemRefType, reinterpretCastOp->getSource (),
1851+ /* offsets=*/ mlir::ValueRange (), newSizes,
1852+ /* strides=*/ mlir::ValueRange (),
1853+ /* static_offsets=*/ newStaticOffsets,
1854+ /* static_sizes=*/ newShape,
1855+ /* static_strides=*/ newStaticStrides);
1856+
1857+ // Replace all uses of the old memref.
1858+ if (failed (replaceAllMemRefUsesWith (oldMemRef,
1859+ /* newMemRef=*/ newReinterpretCast,
1860+ /* extraIndices=*/ {},
1861+ /* indexRemap=*/ oldLayoutMap,
1862+ /* extraOperands=*/ {},
1863+ /* symbolOperands=*/ oldStrides,
1864+ /* domOpFilter=*/ nullptr ,
1865+ /* postDomOpFilter=*/ nullptr ,
1866+ /* allowNonDereferencingOps=*/ true ))) {
1867+ // If it failed (due to escapes for example), bail out.
1868+ newReinterpretCast->erase ();
1869+ return failure ();
1870+ }
1871+
1872+ oldMemRef.replaceAllUsesWith (newReinterpretCast);
1873+ reinterpretCastOp->erase ();
1874+ return success ();
1875+ }
1876+
18491877template LogicalResult
18501878mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
18511879template LogicalResult
0 commit comments