@@ -164,6 +164,37 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
164164 return %0: vector <1 x1 x2 x16 xf32 >
165165}
166166
167+ // -----
168+
169+ // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
170+ // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
171+ // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
172+
173+ // CHECK-LABEL: not_insert_cast_for_contraction_under_mask
174+ // CHECK: %[[MASK:.+]] = vector.constant_mask
175+ // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
176+ // CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
177+ // CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
178+ // CHECK: return %[[RET]] : vector<1x16x16xf32>
179+
180+ #contraction_accesses0 = [
181+ affine_map <(l , i , j , k ) -> (l , i , k )>,
182+ affine_map <(l , i , j , k ) -> (l , k , j )>,
183+ affine_map <(l , i , j , k ) -> (l , i , j )>
184+ ]
185+ #contraction_trait0 = {
186+ indexing_maps = #contraction_accesses0 ,
187+ iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ]
188+ }
189+
190+ func.func @not_insert_cast_for_contraction_under_mask (%arg0: vector <1 x16 x8 xf32 >, %arg1: vector <1 x8 x16 xf32 >, %arg2: vector <1 x16 x16 xf32 >) -> vector <1 x16 x16 xf32 > {
191+ %mask = vector.constant_mask [1 , 15 , 15 , 8 ] : vector <1 x16 x16 x8 xi1 >
192+ %0 = vector.mask %mask {
193+ vector.contract #contraction_trait0 %arg0 , %arg1 , %arg2 : vector <1 x16 x8 xf32 >, vector <1 x8 x16 xf32 > into vector <1 x16 x16 xf32 >
194+ } : vector <1 x16 x16 x8 xi1 > -> vector <1 x16 x16 xf32 >
195+ return %0 : vector <1 x16 x16 xf32 >
196+ }
197+
167198// -----
168199// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
169200func.func @cast_away_extract_strided_slice_leading_one_dims (%arg0: vector <1 x8 x8 xf16 >) -> vector <1 x1 x8 xf16 > {
@@ -253,6 +284,24 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
253284
254285// -----
255286
287+ // CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask
288+ // CHECK: %[[MASK:.+]] = vector.constant_mask
289+ // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
290+ // CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
291+ // CHECK-SAME: vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
292+ // CHECK: return %[[RET]] : vector<1x4xf16>
293+ func.func @not_insert_cast_fo4_transfer_read_under_mask (%arg0: memref <1 x1 x4 xf16 >) -> vector <1 x4 xf16 > {
294+ %c0 = arith.constant 0 : index
295+ %f0 = arith.constant 0. : f16
296+ %mask = vector.constant_mask [1 , 3 ] : vector <1 x4 xi1 >
297+ %ret = vector.mask %mask {
298+ vector.transfer_read %arg0 [%c0 , %c0 , %c0 ], %f0 {in_bounds = [true , true ]} : memref <1 x1 x4 xf16 >, vector <1 x4 xf16 >
299+ } : vector <1 x4 xi1 > -> vector <1 x4 xf16 >
300+ return %ret: vector <1 x4 xf16 >
301+ }
302+
303+ // -----
304+
256305// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
257306func.func @cast_away_transfer_write_leading_one_dims (%arg0: memref <1 x4 x8 x16 xf16 >, %arg1: vector <1 x4 xf16 >) {
258307 // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -286,6 +335,23 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
286335
287336// -----
288337
338+ // CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
339+ // CHECK: %[[MASK:.+]] = vector.constant_mask
340+ // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
341+ // CHECK: vector.mask %[[CASTED_MASK]] {
342+ // CHECK-SAME: vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
343+ // CHECK: return
344+ func.func @not_insert_cast_for_transfer_write_under_mask (%arg0: memref <1 x1 x4 xf16 >, %arg1: vector <1 x4 xf16 >) {
345+ %c0 = arith.constant 0 : index
346+ %mask = vector.constant_mask [1 , 3 ] : vector <1 x4 xi1 >
347+ vector.mask %mask {
348+ vector.transfer_write %arg1 , %arg0 [%c0 , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <1 x4 xf16 >, memref <1 x1 x4 xf16 >
349+ } : vector <1 x4 xi1 >
350+ return
351+ }
352+
353+ // -----
354+
289355// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
290356// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
291357func.func @cast_away_nontrivial_map_masked_transfer_write (%arg0: memref <1 x4 x8 xf16 >, %arg1: vector <1 x1 x4 xf16 >, %arg2: vector <1 x4 x1 xi1 >) {
0 commit comments