@@ -416,9 +416,8 @@ func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memre
416416 // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
417417 // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
418418 // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
419- // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
420- // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
421- // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
419+ // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[READ_SUBVIEW]] (d0, d1) -> (d1, d0) : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<4x?xf32, strided<[1, ?], offset: ?>>
420+ // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] {in_bounds = [true, false]} : memref<4x?xf32, strided<[1, ?], offset: ?>>, vector<4x[8]xf32>
422421 // CHECK-NEXT: return %[[LEGAL_READ]]
423422 %pad = arith.constant 0.0 : f32
424423 %illegalRead = vector.transfer_read %memref [%a , %b ], %pad : memref <?x?xf32 >, vector <[8 ]x4 xf32 >
@@ -434,11 +433,10 @@ func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memre
434433// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
435434func.func @lift_illegal_transpose_to_memory_with_mask (%dim0: index , %dim1: index , %memref: memref <?x?xf32 >, %a: index , %b: index ) -> vector <4 x[8 ]xf32 > {
436435 // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
437- // CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
438- // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
436+ // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[READ_SUBVIEW]]
439437 // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
440438 // CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
441- // CHECK-SAME: %[[MASK]] : memref<?x? xf32, strided<[? , ?], offset: ?>>, vector<4x[8]xf32>
439+ // CHECK-SAME: %[[MASK]] {in_bounds = [true, false]} : memref<4x? xf32, strided<[1 , ?], offset: ?>>, vector<4x[8]xf32>
442440 // CHECK-NEXT: return %[[LEGAL_READ]]
443441 %pad = arith.constant 0.0 : f32
444442 %mask = vector.create_mask %dim0 , %dim1 : vector <[8 ]x4 xi1 >
@@ -453,8 +451,7 @@ func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index
453451// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
454452func.func @lift_illegal_transpose_to_memory_with_arith_extop (%a: index , %b: index , %memref: memref <?x?xi8 >) -> vector <4 x[8 ]xi32 > {
455453 // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
456- // CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
457- // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
454+ // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[READ_SUBVIEW]]
458455 // CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
459456 // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
460457 // CHECK-NEXT: return %[[EXT_TYPE]]
@@ -514,7 +511,7 @@ func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector
514511
515512// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
516513func.func @lift_illegal_2d_shape_cast_to_memory (%a: index , %b: index , %memref: memref <?x?xf32 >) -> vector <1 x[4 ]xf32 > {
517- // CHECK: vector.transfer_read {{.*}} : memref<?x ?xf32, {{.*}}>, vector<1x[4]xf32>
514+ // CHECK: vector.transfer_read {{.*}} : memref<1x ?xf32, {{.*}}>, vector<1x[4]xf32>
518515 // CHECK-NOT: vector.shape_cast
519516 %pad = arith.constant 0.0 : f32
520517 %illegalRead = vector.transfer_read %memref [%a , %b ], %pad {in_bounds = [false , true ]}: memref <?x?xf32 >, vector <[4 ]x1 xf32 >
@@ -526,7 +523,7 @@ func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: m
526523
527524// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
528525func.func @lift_illegal_1d_shape_cast_to_memory (%a: index , %b: index , %memref: memref <?x?xf32 >) -> vector <[4 ]xf32 > {
529- // CHECK: vector.transfer_read {{.*}} : memref<?x ?xf32, {{.*}}>, vector<1x[4]xf32>
526+ // CHECK: vector.transfer_read {{.*}} : memref<1x ?xf32, {{.*}}>, vector<1x[4]xf32>
530527 // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
531528 %pad = arith.constant 0.0 : f32
532529 %illegalRead = vector.transfer_read %memref [%a , %b ], %pad {in_bounds = [false , true ]}: memref <?x?xf32 >, vector <[4 ]x1 xf32 >
0 commit comments