@@ -635,3 +635,106 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
// CHECK: vector.print
// CHECK: return
// CHECK: }
638+
// -----

// Lowering a masked 2-D transfer_read of a [fixed x scalable] vector:
// the fixed leading dim is unpacked into an scf.for loop over 1-D
// scalable transfer_reads, staged through memref.alloca slots.
func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[4]xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[4]xf32>
  return %read : vector<3x[4]xf32>
}
// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
// CHECK-SAME:    %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
// CHECK:         %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         %[[C3:.*]] = arith.constant 3 : index
// CHECK:         %[[C1:.*]] = arith.constant 1 : index
// CHECK:         %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK:         %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK:         %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
// CHECK:         %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
// CHECK:         memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
// CHECK:         %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
// CHECK:         %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
// CHECK:         scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
// CHECK:           %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
// CHECK:           memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
// CHECK:         }
// CHECK:         %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
// CHECK:         return %[[RESULT]] : vector<3x[4]xf32>
// CHECK:       }
671+
// -----

// Lowering a masked 2-D transfer_write of a [fixed x scalable] vector:
// the fixed leading dim is unpacked into an scf.for loop over 1-D
// scalable transfer_writes, staged through memref.alloca slots.
func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memref<3x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<3x[4]xf32>, memref<3x?xf32>
  return
}
// NOTE: the slice loaded from the unpacked *vector* memref is the value
// being written and the slice loaded from the unpacked *mask* memref is
// the per-row mask, so the captures are named VECTOR_SLICE / MASK_SLICE
// accordingly (they were previously swapped).
// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
// CHECK-SAME:    %[[VEC:.*]]: vector<3x[4]xf32>,
// CHECK-SAME:    %[[MEMREF:.*]]: memref<3x?xf32>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         %[[C3:.*]] = arith.constant 3 : index
// CHECK:         %[[C1:.*]] = arith.constant 1 : index
// CHECK:         %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK:         %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK:         %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
// CHECK:         %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
// CHECK:         memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
// CHECK:         memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
// CHECK:         %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
// CHECK:         %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
// CHECK:         scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
// CHECK:           %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
// CHECK:           vector.transfer_write %[[VECTOR_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[MASK_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
// CHECK:         }
// CHECK:         return
// CHECK:       }
704+
// -----

/// The following two tests currently cannot be lowered via unpacking the
/// leading dim since it is scalable. It may be possible to special case
/// this via a dynamic dim in future.

func.func @cannot_lower_transfer_write_with_leading_scalable(%vec: vector<[4]x4xf32>, %arg0: memref<?x4xf32>) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %cst = arith.constant 0.000000e+00 : f32
  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
  return
}
// The transfer_write must survive unchanged (i.e. the lowering bails out).
// CHECK-LABEL: func.func @cannot_lower_transfer_write_with_leading_scalable(
// CHECK-SAME:    %[[VEC:.*]]: vector<[4]x4xf32>,
// CHECK-SAME:    %[[MEMREF:.*]]: memref<?x4xf32>)
// CHECK:         vector.transfer_write %[[VEC]], %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
723+
// -----

// Companion to the test above: a masked transfer_read whose leading dim
// is scalable also cannot be unpacked and must survive lowering unchanged.
func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf32>) -> vector<[4]x4xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %cst = arith.constant 0.000000e+00 : f32
  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
  return %read : vector<[4]x4xf32>
}
// CHECK-LABEL: func.func @cannot_lower_transfer_read_with_leading_scalable(
// CHECK-SAME:    %[[MEMREF:.*]]: memref<?x4xf32>)
// CHECK:         %{{.*}} = vector.transfer_read %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>