Skip to content

Commit 4eaeeab

Browse files
authored
[mlir][memref] Fold extract_strided_metadata(cast(x)) into extract_strided_metadata(x) (#164585)
1 parent 9b5bc98 commit 4eaeeab

File tree

4 files changed

+133
-214
lines changed

4 files changed

+133
-214
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
14371437
atLeastOneReplacement |= replaceConstantUsesOf(
14381438
builder, getLoc(), getStrides(), getConstifiedMixedStrides());
14391439

1440+
// extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
1441+
if (auto prev = getSource().getDefiningOp<CastOp>())
1442+
if (isa<MemRefType>(prev.getSource().getType())) {
1443+
getSourceMutable().assign(prev.getSource());
1444+
atLeastOneReplacement = true;
1445+
}
1446+
14401447
return success(atLeastOneReplacement);
14411448
}
14421449

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
10331033
}
10341034
};
10351035

1036-
/// Replace `base, offset, sizes, strides =
1037-
/// extract_strided_metadata(
1038-
/// cast(src) to dstTy)`
1039-
/// With
1040-
/// ```
1041-
/// base, ... = extract_strided_metadata(src)
1042-
/// offset = !dstTy.srcOffset.isDynamic()
1043-
/// ? dstTy.srcOffset
1044-
/// : extract_strided_metadata(src).offset
1045-
/// sizes = for each srcSize in dstTy.srcSizes:
1046-
/// !srcSize.isDynamic()
1047-
/// ? srcSize
1048-
// : extract_strided_metadata(src).sizes[i]
1049-
/// strides = for each srcStride in dstTy.srcStrides:
1050-
/// !srcStrides.isDynamic()
1051-
/// ? srcStrides
1052-
/// : extract_strided_metadata(src).strides[i]
1053-
/// ```
1054-
///
1055-
/// In other words, consume the `cast` and apply its effects
1056-
/// on the offset, sizes, and strides or compute them directly from `src`.
1057-
class ExtractStridedMetadataOpCastFolder
1058-
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
1059-
using OpRewritePattern::OpRewritePattern;
1060-
1061-
LogicalResult
1062-
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1063-
PatternRewriter &rewriter) const override {
1064-
Value source = extractStridedMetadataOp.getSource();
1065-
auto castOp = source.getDefiningOp<memref::CastOp>();
1066-
if (!castOp)
1067-
return failure();
1068-
1069-
Location loc = extractStridedMetadataOp.getLoc();
1070-
// Check if the source is suitable for extract_strided_metadata.
1071-
SmallVector<Type> inferredReturnTypes;
1072-
if (failed(extractStridedMetadataOp.inferReturnTypes(
1073-
rewriter.getContext(), loc, {castOp.getSource()},
1074-
/*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
1075-
inferredReturnTypes)))
1076-
return rewriter.notifyMatchFailure(castOp,
1077-
"cast source's type is incompatible");
1078-
1079-
auto memrefType = cast<MemRefType>(source.getType());
1080-
unsigned rank = memrefType.getRank();
1081-
SmallVector<OpFoldResult> results;
1082-
results.resize_for_overwrite(rank * 2 + 2);
1083-
1084-
auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
1085-
rewriter, loc, castOp.getSource());
1086-
1087-
// Register the base_buffer.
1088-
results[0] = newExtractStridedMetadata.getBaseBuffer();
1089-
1090-
auto getConstantOrValue = [&rewriter](int64_t constant,
1091-
OpFoldResult ofr) -> OpFoldResult {
1092-
return ShapedType::isStatic(constant)
1093-
? OpFoldResult(rewriter.getIndexAttr(constant))
1094-
: ofr;
1095-
};
1096-
1097-
auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
1098-
assert(sourceStrides.size() == rank && "unexpected number of strides");
1099-
1100-
// Register the new offset.
1101-
results[1] =
1102-
getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
1103-
1104-
const unsigned sizeStartIdx = 2;
1105-
const unsigned strideStartIdx = sizeStartIdx + rank;
1106-
ArrayRef<int64_t> sourceSizes = memrefType.getShape();
1107-
1108-
SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
1109-
SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
1110-
for (unsigned i = 0; i < rank; ++i) {
1111-
results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
1112-
results[strideStartIdx + i] =
1113-
getConstantOrValue(sourceStrides[i], strides[i]);
1114-
}
1115-
rewriter.replaceOp(extractStridedMetadataOp,
1116-
getValueOrCreateConstantIndexOp(rewriter, loc, results));
1117-
return success();
1118-
}
1119-
};
1120-
11211036
/// Replace `base, offset, sizes, strides = extract_strided_metadata(
11221037
/// memory_space_cast(src) to dstTy)`
11231038
/// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
12091124
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
12101125
ExtractStridedMetadataOpReinterpretCastFolder,
12111126
ExtractStridedMetadataOpSubviewFolder,
1212-
ExtractStridedMetadataOpCastFolder,
12131127
ExtractStridedMetadataOpMemorySpaceCastFolder,
12141128
ExtractStridedMetadataOpAssumeAlignmentFolder,
12151129
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
12261140
ExtractStridedMetadataOpSubviewFolder,
12271141
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
12281142
ExtractStridedMetadataOpReinterpretCastFolder,
1229-
ExtractStridedMetadataOpCastFolder,
12301143
ExtractStridedMetadataOpMemorySpaceCastFolder,
12311144
ExtractStridedMetadataOpAssumeAlignmentFolder,
12321145
ExtractStridedMetadataOpExtractStridedMetadataFolder>(

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,132 @@ func.func @scope_merge_without_terminator() {
901901

902902
// -----
903903

904+
// Check that we simplify extract_strided_metadata of cast
905+
// when the source of the cast is compatible with what
906+
// `extract_strided_metadata`s accept.
907+
//
908+
// When we apply the transformation the resulting offset, sizes and strides
909+
// should come straight from the inputs of the cast.
910+
// Additionally the folder on extract_strided_metadata should propagate the
911+
// static information.
912+
//
913+
// CHECK-LABEL: func @extract_strided_metadata_of_cast
914+
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
915+
//
916+
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
917+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
918+
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
919+
//
920+
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
921+
func.func @extract_strided_metadata_of_cast(
922+
%arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
923+
-> (memref<i32>, index,
924+
index, index,
925+
index, index) {
926+
927+
%cast =
928+
memref.cast %arg :
929+
memref<3x?xi32, strided<[4, ?], offset: ?>> to
930+
memref<?x?xi32, strided<[?, ?], offset: ?>>
931+
932+
%base, %base_offset, %sizes:2, %strides:2 =
933+
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
934+
-> memref<i32>, index,
935+
index, index,
936+
index, index
937+
938+
return %base, %base_offset,
939+
%sizes#0, %sizes#1,
940+
%strides#0, %strides#1 :
941+
memref<i32>, index,
942+
index, index,
943+
index, index
944+
}
945+
946+
// -----
947+
948+
// Check that we simplify extract_strided_metadata of cast
949+
// when the source of the cast is compatible with what
950+
// `extract_strided_metadata`s accept.
951+
//
952+
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
953+
// in the destination type.
954+
//
955+
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
956+
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
957+
//
958+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
959+
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
960+
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
961+
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
962+
//
963+
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
964+
func.func @extract_strided_metadata_of_cast_w_csts(
965+
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
966+
-> (memref<i32>, index,
967+
index, index,
968+
index, index) {
969+
970+
%cast =
971+
memref.cast %arg :
972+
memref<?x?xi32, strided<[?, ?], offset: ?>> to
973+
memref<4x?xi32, strided<[?, 18], offset: 25>>
974+
975+
%base, %base_offset, %sizes:2, %strides:2 =
976+
memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
977+
-> memref<i32>, index,
978+
index, index,
979+
index, index
980+
981+
return %base, %base_offset,
982+
%sizes#0, %sizes#1,
983+
%strides#0, %strides#1 :
984+
memref<i32>, index,
985+
index, index,
986+
index, index
987+
}
988+
989+
// -----
990+
991+
// Check that we don't simplify extract_strided_metadata of
992+
// cast when the source of the cast is unranked.
993+
// Unranked memrefs cannot feed into extract_strided_metadata operations.
994+
// Note: Technically we could still fold the sizes and strides.
995+
//
996+
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
997+
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
998+
//
999+
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
1000+
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
1001+
//
1002+
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
1003+
func.func @extract_strided_metadata_of_cast_unranked(
1004+
%arg : memref<*xi32>)
1005+
-> (memref<i32>, index,
1006+
index, index,
1007+
index, index) {
1008+
1009+
%cast =
1010+
memref.cast %arg :
1011+
memref<*xi32> to
1012+
memref<?x?xi32, strided<[?, ?], offset: ?>>
1013+
1014+
%base, %base_offset, %sizes:2, %strides:2 =
1015+
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
1016+
-> memref<i32>, index,
1017+
index, index,
1018+
index, index
1019+
1020+
return %base, %base_offset,
1021+
%sizes#0, %sizes#1,
1022+
%strides#0, %strides#1 :
1023+
memref<i32>, index,
1024+
index, index,
1025+
index, index
1026+
}
1027+
1028+
// -----
1029+
9041030
// CHECK-LABEL: func @reinterpret_noop
9051031
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
9061032
// CHECK-NEXT: return %[[ARG]]

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Lines changed: 0 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,133 +1376,6 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
13761376
memref<i32>, index, index, index, index, index
13771377
}
13781378

1379-
// -----
1380-
1381-
// Check that we simplify extract_strided_metadata of cast
1382-
// when the source of the cast is compatible with what
1383-
// `extract_strided_metadata`s accept.
1384-
//
1385-
// When we apply the transformation the resulting offset, sizes and strides
1386-
// should come straight from the inputs of the cast.
1387-
// Additionally the folder on extract_strided_metadata should propagate the
1388-
// static information.
1389-
//
1390-
// CHECK-LABEL: func @extract_strided_metadata_of_cast
1391-
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
1392-
//
1393-
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
1394-
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
1395-
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
1396-
//
1397-
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
1398-
func.func @extract_strided_metadata_of_cast(
1399-
%arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
1400-
-> (memref<i32>, index,
1401-
index, index,
1402-
index, index) {
1403-
1404-
%cast =
1405-
memref.cast %arg :
1406-
memref<3x?xi32, strided<[4, ?], offset: ?>> to
1407-
memref<?x?xi32, strided<[?, ?], offset: ?>>
1408-
1409-
%base, %base_offset, %sizes:2, %strides:2 =
1410-
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
1411-
-> memref<i32>, index,
1412-
index, index,
1413-
index, index
1414-
1415-
return %base, %base_offset,
1416-
%sizes#0, %sizes#1,
1417-
%strides#0, %strides#1 :
1418-
memref<i32>, index,
1419-
index, index,
1420-
index, index
1421-
}
1422-
1423-
// -----
1424-
1425-
// Check that we simplify extract_strided_metadata of cast
1426-
// when the source of the cast is compatible with what
1427-
// `extract_strided_metadata`s accept.
1428-
//
1429-
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
1430-
// in the destination type.
1431-
//
1432-
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
1433-
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
1434-
//
1435-
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
1436-
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
1437-
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
1438-
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
1439-
//
1440-
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
1441-
func.func @extract_strided_metadata_of_cast_w_csts(
1442-
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
1443-
-> (memref<i32>, index,
1444-
index, index,
1445-
index, index) {
1446-
1447-
%cast =
1448-
memref.cast %arg :
1449-
memref<?x?xi32, strided<[?, ?], offset: ?>> to
1450-
memref<4x?xi32, strided<[?, 18], offset: 25>>
1451-
1452-
%base, %base_offset, %sizes:2, %strides:2 =
1453-
memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
1454-
-> memref<i32>, index,
1455-
index, index,
1456-
index, index
1457-
1458-
return %base, %base_offset,
1459-
%sizes#0, %sizes#1,
1460-
%strides#0, %strides#1 :
1461-
memref<i32>, index,
1462-
index, index,
1463-
index, index
1464-
}
1465-
1466-
// -----
1467-
1468-
// Check that we don't simplify extract_strided_metadata of
1469-
// cast when the source of the cast is unranked.
1470-
// Unranked memrefs cannot feed into extract_strided_metadata operations.
1471-
// Note: Technically we could still fold the sizes and strides.
1472-
//
1473-
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
1474-
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
1475-
//
1476-
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
1477-
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
1478-
//
1479-
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
1480-
func.func @extract_strided_metadata_of_cast_unranked(
1481-
%arg : memref<*xi32>)
1482-
-> (memref<i32>, index,
1483-
index, index,
1484-
index, index) {
1485-
1486-
%cast =
1487-
memref.cast %arg :
1488-
memref<*xi32> to
1489-
memref<?x?xi32, strided<[?, ?], offset: ?>>
1490-
1491-
%base, %base_offset, %sizes:2, %strides:2 =
1492-
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
1493-
-> memref<i32>, index,
1494-
index, index,
1495-
index, index
1496-
1497-
return %base, %base_offset,
1498-
%sizes#0, %sizes#1,
1499-
%strides#0, %strides#1 :
1500-
memref<i32>, index,
1501-
index, index,
1502-
index, index
1503-
}
1504-
1505-
15061379
// -----
15071380

15081381
memref.global "private" @dynamicShmem : memref<0xf16,3>

0 commit comments

Comments
 (0)