@@ -1031,3 +1031,40 @@ func.func @fold_vector_maskedstore_collapse_shape(
10311031// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
10321032// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
10331033// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
1034+ // -----
1035+
1036+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (d0 mod s0)>
1037+ // CHECK-LABEL: fold_expand_shape_dynamic_dim
1038+ func.func @fold_expand_shape_dynamic_dim (%arg0: i64 , %arg1: memref <*xf16 >) {
1039+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
1040+ %c2 = arith.constant 2 : index
1041+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
1042+ %cast = memref.cast %arg1 : memref <*xf16 > to memref <1 x8 x?x128 xf16 >
1043+ // CHECK: %[[CAST:.*]] = memref.cast
1044+ %dim = memref.dim %cast , %c2 : memref <1 x8 x?x128 xf16 >
1045+ // CHECK: %[[DIM:.*]] = memref.dim %[[CAST]], %[[C2]]
1046+ %dim_0 = memref.dim %cast , %c2 : memref <1 x8 x?x128 xf16 >
1047+ %expand_shape = memref.expand_shape %cast [[0 ], [1 ], [2 , 3 ], [4 ]] output_shape [1 , 8 , 1 , %dim_0 , 128 ] : memref <1 x8 x?x128 xf16 > into memref <1 x8 x1 x?x128 xf16 >
1048+ // CHECK-NOT: memref.expand_shape
1049+ %0 = arith.index_cast %arg0 : i64 to index
1050+ // CHECK: %[[IDX:.*]] = arith.index_cast
1051+ %alloc = memref.alloc (%0 ) {alignment = 64 : i64 } : memref <1 x8 x4 x?x128 xf16 >
1052+ affine.for %arg2 = 0 to 8 {
1053+ // CHECK: affine.for %[[ARG2:.*]] = 0 to 8
1054+ affine.for %arg3 = 0 to 4 {
1055+ // CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 4
1056+ affine.for %arg4 = 0 to %0 {
1057+ // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[IDX]]
1058+ affine.for %arg5 = 0 to 128 {
1059+ // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 128
1060+ // CHECK: %[[DIM_0:.*]] = memref.dim %[[CAST]], %[[C2]] : memref<1x8x?x128xf16>
1061+ // CHECK: %[[APPLY_RES:.*]] = affine.apply #[[$MAP]](%[[ARG4]]
1062+ %2 = affine.load %expand_shape [0 , %arg2 , 0 , %arg4 mod symbol (%dim ), %arg5 ] : memref <1 x8 x1 x?x128 xf16 >
1063+ // CHECK: memref.load %[[CAST]][%[[C0]], %[[ARG2]], %[[APPLY_RES]], %[[ARG5]]] : memref<1x8x?x128xf16>
1064+ affine.store %2 , %alloc [0 , %arg2 , %arg3 , %arg4 , %arg5 ] : memref <1 x8 x4 x?x128 xf16 >
1065+ }
1066+ }
1067+ }
1068+ }
1069+ return
1070+ }
0 commit comments