@@ -44,97 +44,6 @@ using namespace mlir;
4444// Utility functions
4545// ===----------------------------------------------------------------------===//
4646
47- // / Given the 'indices' of a load/store operation where the memref is a result
48- // / of a expand_shape op, returns the indices w.r.t to the source memref of the
49- // / expand_shape op. For example
50- // /
51- // / %0 = ... : memref<12x42xf32>
52- // / %1 = memref.expand_shape %0 [[0, 1], [2]]
53- // / : memref<12x42xf32> into memref<2x6x42xf32>
54- // / %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
55- // /
56- // / could be folded into
57- // /
58- // / %2 = load %0[6 * i1 + i2, %i3] :
59- // / memref<12x42xf32>
60- static LogicalResult resolveSourceIndicesExpandShape (
61- Location loc, PatternRewriter &rewriter,
62- memref::ExpandShapeOp expandShapeOp, ValueRange indices,
63- SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
64- SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape ();
65-
66- // Traverse all reassociation groups to determine the appropriate indices
67- // corresponding to each one of them post op folding.
68- for (ArrayRef<int64_t > group : expandShapeOp.getReassociationIndices ()) {
69- assert (!group.empty () && " association indices groups cannot be empty" );
70- int64_t groupSize = group.size ();
71- if (groupSize == 1 ) {
72- sourceIndices.push_back (indices[group[0 ]]);
73- continue ;
74- }
75- SmallVector<OpFoldResult> groupBasis =
76- llvm::map_to_vector (group, [&](int64_t d) { return destShape[d]; });
77- SmallVector<Value> groupIndices =
78- llvm::map_to_vector (group, [&](int64_t d) { return indices[d]; });
79- Value collapsedIndex = rewriter.create <affine::AffineLinearizeIndexOp>(
80- loc, groupIndices, groupBasis, /* disjoint=*/ startsInbounds);
81- sourceIndices.push_back (collapsedIndex);
82- }
83- return success ();
84- }
85-
86- // / Given the 'indices' of a load/store operation where the memref is a result
87- // / of a collapse_shape op, returns the indices w.r.t to the source memref of
88- // / the collapse_shape op. For example
89- // /
90- // / %0 = ... : memref<2x6x42xf32>
91- // / %1 = memref.collapse_shape %0 [[0, 1], [2]]
92- // / : memref<2x6x42xf32> into memref<12x42xf32>
93- // / %2 = load %1[%i1, %i2] : memref<12x42xf32>
94- // /
95- // / could be folded into
96- // /
97- // / %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
98- // / memref<2x6x42xf32>
99- static LogicalResult
100- resolveSourceIndicesCollapseShape (Location loc, PatternRewriter &rewriter,
101- memref::CollapseShapeOp collapseShapeOp,
102- ValueRange indices,
103- SmallVectorImpl<Value> &sourceIndices) {
104- // Note: collapse_shape requires a strided memref, we can do this.
105- auto metadata = rewriter.create <memref::ExtractStridedMetadataOp>(
106- loc, collapseShapeOp.getSrc ());
107- SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes ();
108- for (auto [index, group] :
109- llvm::zip (indices, collapseShapeOp.getReassociationIndices ())) {
110- assert (!group.empty () && " association indices groups cannot be empty" );
111- int64_t groupSize = group.size ();
112-
113- if (groupSize == 1 ) {
114- sourceIndices.push_back (index);
115- continue ;
116- }
117-
118- SmallVector<OpFoldResult> basis =
119- llvm::map_to_vector (group, [&](int64_t d) { return sourceSizes[d]; });
120- auto delinearize = rewriter.create <affine::AffineDelinearizeIndexOp>(
121- loc, index, basis, /* hasOuterBound=*/ true );
122- llvm::append_range (sourceIndices, delinearize.getResults ());
123- }
124- if (collapseShapeOp.getReassociationIndices ().empty ()) {
125- auto zeroAffineMap = rewriter.getConstantAffineMap (0 );
126- int64_t srcRank =
127- cast<MemRefType>(collapseShapeOp.getViewSource ().getType ()).getRank ();
128- OpFoldResult ofr = affine::makeComposedFoldedAffineApply (
129- rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
130- for (int64_t i = 0 ; i < srcRank; i++) {
131- sourceIndices.push_back (
132- getValueOrCreateConstantIndexOp (rewriter, loc, ofr));
133- }
134- }
135- return success ();
136- }
137-
13847// / Helpers to access the memref operand for each op.
13948template <typename LoadOrStoreOpTy>
14049static Value getMemRefOperand (LoadOrStoreOpTy op) {
0 commit comments