@@ -113,6 +113,18 @@ static Value getTargetMemref(Operation *op) {
113113 .Default ([](auto ) { return Value{}; });
114114}
115115
116+ template <typename T>
117+ static void castResult (T oper, T newOper, Location loc,
118+ PatternRewriter &rewriter) {
119+ memref::ExtractStridedMetadataOp stridedMetadata =
120+ rewriter.create <memref::ExtractStridedMetadataOp>(loc, oper);
121+ rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp>(
122+ oper, cast<MemRefType>(oper.getType ()), newOper,
123+ /* offset=*/ rewriter.getIndexAttr (0 ),
124+ stridedMetadata.getConstifiedMixedSizes (),
125+ stridedMetadata.getConstifiedMixedStrides ());
126+ }
127+
116128template <typename T>
117129static void replaceOp (T op, PatternRewriter &rewriter, Value flatMemref,
118130 Value offset) {
@@ -122,25 +134,13 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
122134 auto newAlloc = rewriter.create <memref::AllocOp>(
123135 loc, cast<MemRefType>(flatMemref.getType ()),
124136 oper.getAlignmentAttr ());
125- memref::ExtractStridedMetadataOp stridedMetadata =
126- rewriter.create <memref::ExtractStridedMetadataOp>(loc, oper);
127- rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp>(
128- op, cast<MemRefType>(oper.getType ()), newAlloc,
129- /* offset=*/ rewriter.getIndexAttr (0 ),
130- stridedMetadata.getConstifiedMixedSizes (),
131- stridedMetadata.getConstifiedMixedStrides ());
137+ castResult (oper, newAlloc, loc, rewriter);
132138 })
133139 .template Case <memref::AllocaOp>([&](auto oper) {
134140 auto newAlloca = rewriter.create <memref::AllocaOp>(
135141 loc, cast<MemRefType>(flatMemref.getType ()),
136142 oper.getAlignmentAttr ());
137- memref::ExtractStridedMetadataOp stridedMetadata =
138- rewriter.create <memref::ExtractStridedMetadataOp>(loc, oper);
139- rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp>(
140- op, cast<MemRefType>(oper.getType ()), newAlloca,
141- /* offset=*/ rewriter.getIndexAttr (0 ),
142- stridedMetadata.getConstifiedMixedSizes (),
143- stridedMetadata.getConstifiedMixedStrides ());
143+ castResult (oper, newAlloca, loc, rewriter);
144144 })
145145 .template Case <memref::LoadOp>([&](auto op) {
146146 auto newLoad = rewriter.create <memref::LoadOp>(
0 commit comments