Skip to content

Commit 189cddf

Browse files
committed
refactor
1 parent 32bfeb5 commit 189cddf

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,18 @@ static Value getTargetMemref(Operation *op) {
113113
.Default([](auto) { return Value{}; });
114114
}
115115

116+
template <typename T>
117+
static void castResult(T oper, T newOper, Location loc,
118+
PatternRewriter &rewriter) {
119+
memref::ExtractStridedMetadataOp stridedMetadata =
120+
rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
121+
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
122+
oper, cast<MemRefType>(oper.getType()), newOper,
123+
/*offset=*/rewriter.getIndexAttr(0),
124+
stridedMetadata.getConstifiedMixedSizes(),
125+
stridedMetadata.getConstifiedMixedStrides());
126+
}
127+
116128
template <typename T>
117129
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
118130
Value offset) {
@@ -122,25 +134,13 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
122134
auto newAlloc = rewriter.create<memref::AllocOp>(
123135
loc, cast<MemRefType>(flatMemref.getType()),
124136
oper.getAlignmentAttr());
125-
memref::ExtractStridedMetadataOp stridedMetadata =
126-
rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
127-
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
128-
op, cast<MemRefType>(oper.getType()), newAlloc,
129-
/*offset=*/rewriter.getIndexAttr(0),
130-
stridedMetadata.getConstifiedMixedSizes(),
131-
stridedMetadata.getConstifiedMixedStrides());
137+
castResult(oper, newAlloc, loc, rewriter);
132138
})
133139
.template Case<memref::AllocaOp>([&](auto oper) {
134140
auto newAlloca = rewriter.create<memref::AllocaOp>(
135141
loc, cast<MemRefType>(flatMemref.getType()),
136142
oper.getAlignmentAttr());
137-
memref::ExtractStridedMetadataOp stridedMetadata =
138-
rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
139-
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
140-
op, cast<MemRefType>(oper.getType()), newAlloca,
141-
/*offset=*/rewriter.getIndexAttr(0),
142-
stridedMetadata.getConstifiedMixedSizes(),
143-
stridedMetadata.getConstifiedMixedStrides());
143+
castResult(oper, newAlloca, loc, rewriter);
144144
})
145145
.template Case<memref::LoadOp>([&](auto op) {
146146
auto newLoad = rewriter.create<memref::LoadOp>(

0 commit comments

Comments
 (0)