
Conversation

@arnab-polymage (Contributor)

Pass the dimensional and symbolic operands to the linearized access map of the input memref of the memref.expand_shape op in the correct order: the load/store indices (the map's dimensional operands) first, followed by the suffix products (its symbolic operands).
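
For context, affine.apply binds its operands positionally: dimensional operands first (in parentheses), then symbolic operands (in square brackets). A minimal sketch of that convention, with an illustrative map, function name, and values that are not taken from this patch:

  // %i binds to d0 and %size binds to s0; the flat operand list handed to
  // the map must therefore contain the dims first, then the symbols.
  #map = affine_map<(d0)[s0] -> (d0 mod s0)>
  func.func @operand_order(%i: index, %size: index) -> index {
    %r = affine.apply #map (%i)[%size]
    return %r : index
  }

Before this change, resolveSourceIndicesExpandShape appended the suffix products (the map's symbols) ahead of the load indices (its dims), so the dims ended up bound to the suffix-product values rather than to the indices; swapping the two append calls restores the expected dims-then-symbols order.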

@llvmbot (Member) commented Feb 26, 2025

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Arnab Dutta (arnab-polymage)

Changes

Pass the dimensional and symbolic operands to the linearized access map of the input memref of the memref.expand_shape op in the correct order: the load/store indices (the map's dimensional operands) first, followed by the suffix products (its symbolic operands).


Full diff: https://github.com/llvm/llvm-project/pull/128844.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+2-2)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+37)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 8e927a60087fc..930c5d47839ff 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -126,11 +126,11 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
     for (int64_t i = 0; i < groupSize; i++)
       dynamicIndices[i] = indices[groups[i]];
 
-    // Supply suffix product results followed by load op indices as operands
+    // Supply load op indices as operands followed by suffix product results
     // to the map.
     SmallVector<OpFoldResult> mapOperands;
-    llvm::append_range(mapOperands, suffixProduct);
     llvm::append_range(mapOperands, dynamicIndices);
+    llvm::append_range(mapOperands, suffixProduct);
 
     // Creating maximally folded and composed affine.apply composes better
     // with other transformations without interleaving canonicalization
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 327cacf7d9a20..e52dd15b0fdbb 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -1031,3 +1031,40 @@ func.func @fold_vector_maskedstore_collapse_shape(
 //       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
 //       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
 //       CHECK:   vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (d0 mod s0)>
+// CHECK-LABEL: fold_expand_shape_dynamic_dim
+func.func @fold_expand_shape_dynamic_dim(%arg0: i64, %arg1: memref<*xf16>) {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  // CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
+  %cast = memref.cast %arg1 : memref<*xf16> to memref<1x8x?x128xf16>
+  // CHECK: %[[CAST:.*]] = memref.cast
+  %dim = memref.dim %cast, %c2 : memref<1x8x?x128xf16>
+  // CHECK: %[[DIM:.*]] = memref.dim %[[CAST]], %[[C2]]
+  %dim_0 = memref.dim %cast, %c2 : memref<1x8x?x128xf16>
+  %expand_shape = memref.expand_shape %cast [[0], [1], [2, 3], [4]] output_shape [1, 8, 1, %dim_0, 128] : memref<1x8x?x128xf16> into memref<1x8x1x?x128xf16>
+  // CHECK-NOT: memref.expand_shape
+  %0 = arith.index_cast %arg0 : i64 to index
+  // CHECK: %[[IDX:.*]] = arith.index_cast
+  %alloc = memref.alloc(%0) {alignment = 64 : i64} : memref<1x8x4x?x128xf16>
+  affine.for %arg2 = 0 to 8 {
+  // CHECK: affine.for %[[ARG2:.*]] = 0 to 8
+    affine.for %arg3 = 0 to 4 {
+    // CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 4
+      affine.for %arg4 = 0 to %0 {
+      // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[IDX]]
+        affine.for %arg5 = 0 to 128 {
+          // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 128
+          // CHECK: %[[DIM_0:.*]] = memref.dim %[[CAST]], %[[C2]] : memref<1x8x?x128xf16>
+          // CHECK: %[[APPLY_RES:.*]] = affine.apply #[[$MAP]](%[[ARG4]]
+          %2 = affine.load %expand_shape[0, %arg2, 0, %arg4 mod symbol(%dim), %arg5] : memref<1x8x1x?x128xf16>
+          // CHECK: memref.load %[[CAST]][%[[C0]], %[[ARG2]], %[[APPLY_RES]], %[[ARG5]]] : memref<1x8x?x128xf16>
+          affine.store %2, %alloc[0, %arg2, %arg3, %arg4, %arg5] : memref<1x8x4x?x128xf16>
+        }
+      }
+    }
+  }
+  return
+}

@arnab-polymage force-pushed the ornib/expand_shape_fold_bug branch from d824095 to 01074c1 on Feb 26, 2025, 08:36.