@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
1414// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
1515func.func @create_scalable_vector_mask_to_constant_mask () -> (vector <[8 ]xi1 >) {
1616 %c -1 = arith.constant -1 : index
17- // CHECK: vector.constant_mask [0] : vector<[8]xi1>
17+ // CHECK: arith.constant dense<false> : vector<[8]xi1>
1818 %0 = vector.create_mask %c -1 : vector <[8 ]xi1 >
1919 return %0 : vector <[8 ]xi1 >
2020}
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
3636func.func @create_vector_mask_to_constant_mask_truncation_neg () -> (vector <4 x3 xi1 >) {
3737 %cneg2 = arith.constant -2 : index
3838 %c5 = arith.constant 5 : index
39- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
39+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
4040 %0 = vector.create_mask %c5 , %cneg2 : vector <4 x3 xi1 >
4141 return %0 : vector <4 x3 xi1 >
4242}
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
4747func.func @create_vector_mask_to_constant_mask_truncation_zero () -> (vector <4 x3 xi1 >) {
4848 %c2 = arith.constant 2 : index
4949 %c0 = arith.constant 0 : index
50- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
50+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
5151 %0 = vector.create_mask %c0 , %c2 : vector <4 x3 xi1 >
5252 return %0 : vector <4 x3 xi1 >
5353}
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
6060 %c16 = arith.constant 16 : index
6161 %0 = vector.vscale
6262 %1 = arith.muli %0 , %c16 : index
63- // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
63+ // CHECK: arith.constant dense<true> : vector<8x[16]xi1>
6464 %10 = vector.create_mask %c8 , %1 : vector <8 x[16 ]xi1 >
6565 return %10 : vector <8 x[16 ]xi1 >
6666}
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
272272
273273// -----
274274
275+ // CHECK-LABEL: constant_mask_to_true_splat
276+ func.func @constant_mask_to_true_splat () -> vector <2 x4 xi1 > {
277+ // CHECK: arith.constant dense<true>
278+ // CHECK-NOT: vector.constant_mask
279+ %0 = vector.constant_mask [2 , 4 ] : vector <2 x4 xi1 >
280+ return %0 : vector <2 x4 xi1 >
281+ }
282+
283+ // CHECK-LABEL: constant_mask_to_false_splat
284+ func.func @constant_mask_to_false_splat () -> vector <2 x4 xi1 > {
285+ // CHECK: arith.constant dense<false>
286+ // CHECK-NOT: vector.constant_mask
287+ %0 = vector.constant_mask [0 , 0 ] : vector <2 x4 xi1 >
288+ return %0 : vector <2 x4 xi1 >
289+ }
290+
291+ // CHECK-LABEL: constant_mask_to_true_splat_0d
292+ func.func @constant_mask_to_true_splat_0d () -> vector <i1 > {
293+ // CHECK: arith.constant dense<true>
294+ // CHECK-NOT: vector.constant_mask
295+ %0 = vector.constant_mask [1 ] : vector <i1 >
296+ return %0 : vector <i1 >
297+ }
298+
275299// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
276300func.func @constant_mask_transpose_to_transposed_constant_mask () -> (vector <2 x3 x4 xi1 >, vector <4 x2 x3 xi1 >) {
277301 // CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
289313 %1 = vector.extract_strided_slice %0
290314 {offsets = [0 , 0 ], sizes = [2 , 2 ], strides = [1 , 1 ]}
291315 : vector <4 x3 xi1 > to vector <2 x2 xi1 >
292- // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
316+ // CHECK: arith.constant dense<true> : vector<2x2xi1>
293317 return %1 : vector <2 x2 xi1 >
294318}
295319
@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
322346 %1 = vector.extract_strided_slice %0
323347 {offsets = [2 , 0 ], sizes = [2 , 2 ], strides = [1 , 1 ]}
324348 : vector <4 x3 xi1 > to vector <2 x2 xi1 >
325- // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
349+ // CHECK: arith.constant dense<false> : vector<2x2xi1>
326350 return %1 : vector <2 x2 xi1 >
327351}
328352
@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
333357 %1 = vector.extract_strided_slice %0
334358 {offsets = [0 , 2 ], sizes = [2 , 1 ], strides = [1 , 1 ]}
335359 : vector <4 x3 xi1 > to vector <2 x1 xi1 >
336- // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
360+ // CHECK: arith.constant dense<false> : vector<2x1xi1>
337361 return %1 : vector <2 x1 xi1 >
338362}
339363
@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
344368 %1 = vector.extract_strided_slice %0
345369 {offsets = [0 , 1 ], sizes = [2 , 1 ], strides = [1 , 1 ]}
346370 : vector <4 x3 xi1 > to vector <2 x1 xi1 >
347- // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
371+ // CHECK: arith.constant dense<true> : vector<2x1xi1>
348372 return %1 : vector <2 x1 xi1 >
349373}
350374
0 commit comments