Skip to content

Commit 01074c1

Browse files
Fix bug in fold-memref-alias-ops pass
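Pass the operands to the linearized access map of the memref.expand_shape op's input memref in the correct order: the load op indices first, followed by the suffix product results. Previously they were supplied in the reverse order.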
1 parent 1a114fa commit 01074c1

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,11 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
126126
for (int64_t i = 0; i < groupSize; i++)
127127
dynamicIndices[i] = indices[groups[i]];
128128

129-
// Supply suffix product results followed by load op indices as operands
129+
// Supply load op indices as operands followed by suffix product results
130130
// to the map.
131131
SmallVector<OpFoldResult> mapOperands;
132-
llvm::append_range(mapOperands, suffixProduct);
133132
llvm::append_range(mapOperands, dynamicIndices);
133+
llvm::append_range(mapOperands, suffixProduct);
134134

135135
// Creating maximally folded and composed affine.apply composes better
136136
// with other transformations without interleaving canonicalization

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,3 +1031,40 @@ func.func @fold_vector_maskedstore_collapse_shape(
10311031
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
10321032
// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
10331033
// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
1034+
// -----
1035+
1036+
// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (d0 mod s0)>
1037+
// CHECK-LABEL: fold_expand_shape_dynamic_dim
1038+
func.func @fold_expand_shape_dynamic_dim(%arg0: i64, %arg1: memref<*xf16>) {
1039+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
1040+
%c2 = arith.constant 2 : index
1041+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
1042+
%cast = memref.cast %arg1 : memref<*xf16> to memref<1x8x?x128xf16>
1043+
// CHECK: %[[CAST:.*]] = memref.cast
1044+
%dim = memref.dim %cast, %c2 : memref<1x8x?x128xf16>
1045+
// CHECK: %[[DIM:.*]] = memref.dim %[[CAST]], %[[C2]]
1046+
%dim_0 = memref.dim %cast, %c2 : memref<1x8x?x128xf16>
1047+
%expand_shape = memref.expand_shape %cast [[0], [1], [2, 3], [4]] output_shape [1, 8, 1, %dim_0, 128] : memref<1x8x?x128xf16> into memref<1x8x1x?x128xf16>
1048+
// CHECK-NOT: memref.expand_shape
1049+
%0 = arith.index_cast %arg0 : i64 to index
1050+
// CHECK: %[[IDX:.*]] = arith.index_cast
1051+
%alloc = memref.alloc(%0) {alignment = 64 : i64} : memref<1x8x4x?x128xf16>
1052+
affine.for %arg2 = 0 to 8 {
1053+
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8
1054+
affine.for %arg3 = 0 to 4 {
1055+
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 4
1056+
affine.for %arg4 = 0 to %0 {
1057+
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[IDX]]
1058+
affine.for %arg5 = 0 to 128 {
1059+
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 128
1060+
// CHECK: %[[DIM_0:.*]] = memref.dim %[[CAST]], %[[C2]] : memref<1x8x?x128xf16>
1061+
// CHECK: %[[APPLY_RES:.*]] = affine.apply #[[$MAP]](%[[ARG4]]
1062+
%2 = affine.load %expand_shape[0, %arg2, 0, %arg4 mod symbol(%dim), %arg5] : memref<1x8x1x?x128xf16>
1063+
// CHECK: memref.load %[[CAST]][%[[C0]], %[[ARG2]], %[[APPLY_RES]], %[[ARG5]]] : memref<1x8x?x128xf16>
1064+
affine.store %2, %alloc[0, %arg2, %arg3, %arg4, %arg5] : memref<1x8x4x?x128xf16>
1065+
}
1066+
}
1067+
}
1068+
}
1069+
return
1070+
}

0 commit comments

Comments
 (0)