Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,11 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
for (int64_t i = 0; i < groupSize; i++)
dynamicIndices[i] = indices[groups[i]];

// Supply suffix product results followed by load op indices as operands
// Supply load op indices as operands followed by suffix product results
// to the map.
SmallVector<OpFoldResult> mapOperands;
llvm::append_range(mapOperands, suffixProduct);
llvm::append_range(mapOperands, dynamicIndices);
llvm::append_range(mapOperands, suffixProduct);

// Creating maximally folded and composed affine.apply composes better
// with other transformations without interleaving canonicalization
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1031,3 +1031,40 @@ func.func @fold_vector_maskedstore_collapse_shape(
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
// -----

// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (d0 mod s0)>
// CHECK-LABEL: fold_expand_shape_dynamic_dim
func.func @fold_expand_shape_dynamic_dim(%arg0: i64, %arg1: memref<*xf16>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%c2 = arith.constant 2 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
%cast = memref.cast %arg1 : memref<*xf16> to memref<1x8x?x128xf16>
// CHECK: %[[CAST:.*]] = memref.cast
%dim = memref.dim %cast, %c2 : memref<1x8x?x128xf16>
// CHECK: %[[DIM:.*]] = memref.dim %[[CAST]], %[[C2]]
%dim_0 = memref.dim %cast, %c2 : memref<1x8x?x128xf16>
%expand_shape = memref.expand_shape %cast [[0], [1], [2, 3], [4]] output_shape [1, 8, 1, %dim_0, 128] : memref<1x8x?x128xf16> into memref<1x8x1x?x128xf16>
// CHECK-NOT: memref.expand_shape
%0 = arith.index_cast %arg0 : i64 to index
// CHECK: %[[IDX:.*]] = arith.index_cast
%alloc = memref.alloc(%0) {alignment = 64 : i64} : memref<1x8x4x?x128xf16>
affine.for %arg2 = 0 to 8 {
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8
affine.for %arg3 = 0 to 4 {
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 4
affine.for %arg4 = 0 to %0 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[IDX]]
affine.for %arg5 = 0 to 128 {
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 128
// CHECK: %[[DIM_0:.*]] = memref.dim %[[CAST]], %[[C2]] : memref<1x8x?x128xf16>
// CHECK: %[[APPLY_RES:.*]] = affine.apply #[[$MAP]](%[[ARG4]]
%2 = affine.load %expand_shape[0, %arg2, 0, %arg4 mod symbol(%dim), %arg5] : memref<1x8x1x?x128xf16>
// CHECK: memref.load %[[CAST]][%[[C0]], %[[ARG2]], %[[APPLY_RES]], %[[ARG5]]] : memref<1x8x?x128xf16>
affine.store %2, %alloc[0, %arg2, %arg3, %arg4, %arg5] : memref<1x8x4x?x128xf16>
}
}
}
}
return
}
Loading