@@ -131,10 +131,42 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
131131
132132// -----
133133
134+ /// The leading dynamic shapes don't affect whether this example is flattenable
135+ /// or not as those dynamic shapes are not candidates for flattening anyway.
136+
137+ func.func @transfer_read_leading_dynamic_dims (
138+ %arg : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>,
139+ %idx_1 : index ,
140+ %idx_2 : index ) -> vector <8 x4 xi8 > {
141+
142+ %c0_i8 = arith.constant 0 : i8
143+ %c0 = arith.constant 0 : index
144+ %result = vector.transfer_read %arg [%idx_1 , %idx_2 , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, vector <8 x4 xi8 >
145+ return %result : vector <8 x4 xi8 >
146+ }
147+
148+ // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
149+ // CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
150+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
151+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
152+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
153+ // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
154+ // CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
155+ // CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
156+ // CHECK-SAME: {in_bounds = [true]}
157+ // CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
158+ // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
159+ // CHECK: return %[[VEC2D]] : vector<8x4xi8>
160+
161+ // CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
162+ // CHECK-128B: memref.collapse_shape
163+
164+ // -----
165+
134166// The input memref has a dynamic trailing shape and hence is not flattened.
135167// TODO: This case could be supported via memref.dim
136168
137- func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes (
169+ func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim (
138170 %idx_1: index ,
139171 %idx_2: index ,
140172 %m_in: memref <1 x?x4 x6 xi32 >) -> vector <1 x2 x6 xi32 > {
@@ -146,11 +178,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
146178 return %v : vector <1 x2 x6 xi32 >
147179}
148180
149- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
181+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
150182// CHECK-NOT: memref.collapse_shape
151183// CHECK-NOT: vector.shape_cast
152184
153- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
185+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
154186// CHECK-128B-NOT: memref.collapse_shape
155187
156188// -----
@@ -345,10 +377,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
345377
346378// -----
347379
380+ // The leading dynamic shapes don't affect whether this example is flattenable
381+ // or not as those dynamic shapes are not candidates for flattening anyway.
382+
383+ func.func @transfer_write_leading_dynamic_dims (
384+ %vec : vector <8 x4 xi8 >,
385+ %arg : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>,
386+ %idx_1 : index ,
387+ %idx_2 : index ) {
388+
389+ %c0 = arith.constant 0 : index
390+ vector.transfer_write %vec , %arg [%idx_1 , %idx_2 , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <8 x4 xi8 >, memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>
391+ return
392+ }
393+
394+ // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
395+ // CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
396+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
397+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
398+ // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
399+ // CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
400+ // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
401+ // CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
402+ // CHECK-SAME: {in_bounds = [true]}
403+ // CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
404+
405+ // CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
406+ // CHECK-128B: memref.collapse_shape
407+
408+ // -----
409+
348410// The input memref has a dynamic trailing shape and hence is not flattened.
349411// TODO: This case could be supported via memref.dim
350412
351- func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes (
413+ func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim (
352414 %idx_1: index ,
353415 %idx_2: index ,
354416 %vec : vector <1 x2 x6 xi32 >,
@@ -361,11 +423,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
361423 return
362424}
363425
364- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes (
426+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim (
365427// CHECK-NOT: memref.collapse_shape
366428// CHECK-NOT: vector.shape_cast
367429
368- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes (
430+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim (
369431// CHECK-128B-NOT: memref.collapse_shape
370432
371433// -----
@@ -434,56 +496,10 @@ func.func @transfer_write_non_contiguous_src(
434496// -----
435497
436498///----------------------------------------------------------------------------------------
437- /// TODO: Categorize + re-format
499+ /// [Pattern: DropUnitDimFromElementwiseOps]
500+ /// TODO: Move to a dedicated file - there's no "flattening" in the following tests
438501///----------------------------------------------------------------------------------------
439502
440- func.func @transfer_read_flattenable_with_dynamic_dims_and_indices (%arg0 : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) -> vector <8 x4 xi8 > {
441- %c0_i8 = arith.constant 0 : i8
442- %c0 = arith.constant 0 : index
443- %result = vector.transfer_read %arg0 [%arg1 , %arg2 , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, vector <8 x4 xi8 >
444- return %result : vector <8 x4 xi8 >
445- }
446-
447- // CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
448- // CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
449- // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
450- // CHECK: %[[C0:.+]] = arith.constant 0 : index
451- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
452- // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
453- // CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
454- // CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
455- // CHECK-SAME: {in_bounds = [true]}
456- // CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
457- // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
458- // CHECK: return %[[VEC2D]] : vector<8x4xi8>
459-
460- // CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
461- // CHECK-128B: memref.collapse_shape
462-
463- // -----
464-
465- func.func @transfer_write_flattenable_with_dynamic_dims_and_indices (%vec : vector <8 x4 xi8 >, %dst : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) {
466- %c0 = arith.constant 0 : index
467- vector.transfer_write %vec , %dst [%arg1 , %arg2 , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <8 x4 xi8 >, memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>
468- return
469- }
470-
471- // CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
472- // CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
473- // CHECK: %[[C0:.+]] = arith.constant 0 : index
474- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
475- // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
476- // CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
477- // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
478- // CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
479- // CHECK-SAME: {in_bounds = [true]}
480- // CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
481-
482- // CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
483- // CHECK-128B: memref.collapse_shape
484-
485- // -----
486-
487503func.func @fold_unit_dim_add_basic (%arg0 : vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
488504 %add = arith.addi %arg0 , %arg0 : vector <1 x8 xi32 >
489505 return %add : vector <1 x8 xi32 >
0 commit comments