@@ -380,9 +380,14 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector
380380// CHECK: return [[r3]] : vector<4x4xf32>
381381
382382
383- // CHECK-LABEL: func.func @unroll_2D_vector_load(
383+ func.func @vector_load_2D (%mem: memref <4 x4 xf16 >) -> vector <4 x4 xf16 > {
384+ %c0 = arith.constant 0 : index
385+ %0 = vector.load %mem [%c0 , %c0 ] : memref <4 x4 xf16 >, vector <4 x4 xf16 >
386+ return %0 : vector <4 x4 xf16 >
387+ }
388+
389+ // CHECK-LABEL: func.func @vector_load_2D(
384390// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
385- func.func @unroll_2D_vector_load (%arg0: memref <4 x4 xf16 >) -> vector <4 x4 xf16 > {
386391 // CHECK: %[[C3:.*]] = arith.constant 3 : index
387392 // CHECK: %[[C2:.*]] = arith.constant 2 : index
388393 // CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -397,14 +402,16 @@ func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
397402 // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
398403 // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
399404 // CHECK: return %[[V7]] : vector<4x4xf16>
405+
406+
407+ func.func @vector_store_2D (%mem: memref <4 x4 xf16 >, %v: vector <4 x4 xf16 >) {
400408 %c0 = arith.constant 0 : index
401- %0 = vector.load %arg0 [%c0 , %c0 ] : memref <4 x4 xf16 >, vector <4 x4 xf16 >
402- return %0 : vector < 4 x 4 x f16 >
409+ vector.store %v , %mem [%c0 , %c0 ] : memref <4 x4 xf16 >, vector <4 x4 xf16 >
410+ return
403411}
404412
405- // CHECK-LABEL: func.func @unroll_2D_vector_store (
413+ // CHECK-LABEL: func.func @vector_store_2D (
406414// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
407- func.func @unroll_2D_vector_store (%arg0: memref <4 x4 xf16 >, %arg1: vector <4 x4 xf16 >) {
408415 // CHECK: %[[C3:.*]] = arith.constant 3 : index
409416 // CHECK: %[[C2:.*]] = arith.constant 2 : index
410417 // CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -417,14 +424,16 @@ func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>
417424 // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
418425 // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
419426 // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
420- %c0 = arith.constant 0 : index
421- vector.store %arg1 , %arg0 [%c0 , %c0 ] : memref <4 x4 xf16 >, vector <4 x4 xf16 >
422- return
427+
428+
429+ func.func @vector_load_4D_to_2D (%mem: memref <4 x4 x4 x4 xf16 >) -> vector <2 x2 xf16 > {
430+ %c1 = arith.constant 1 : index
431+ %0 = vector.load %mem [%c1 , %c1 , %c1 , %c1 ] : memref <4 x4 x4 x4 xf16 >, vector <2 x2 xf16 >
432+ return %0 : vector <2 x2 xf16 >
423433}
424434
425- // CHECK-LABEL: func.func @unroll_vector_load (
435+ // CHECK-LABEL: func.func @vector_load_4D_to_2D (
426436// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
427- func.func @unroll_vector_load (%arg0: memref <4 x4 x4 x4 xf16 >) -> vector <2 x2 xf16 > {
428437 // CHECK: %[[C2:.*]] = arith.constant 2 : index
429438 // CHECK: %[[C1:.*]] = arith.constant 1 : index
430439 // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
@@ -433,21 +442,18 @@ func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
433442 // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
434443 // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
435444 // CHECK: return %[[V3]] : vector<2x2xf16>
445+
446+ func.func @vector_store_2D_to_4D (%mem: memref <4 x4 x4 x4 xf16 >, %v: vector <2 x2 xf16 >) {
436447 %c1 = arith.constant 1 : index
437- %0 = vector.load %arg0 [%c1 , %c1 , %c1 , %c1 ] : memref <4 x4 x4 x4 xf16 >, vector <2 x2 xf16 >
438- return %0 : vector < 2 x 2 x f16 >
448+ vector.store %v , %mem [%c1 , %c1 , %c1 , %c1 ] : memref <4 x4 x4 x4 xf16 >, vector <2 x2 xf16 >
449+ return
439450}
440451
441- // CHECK-LABEL: func.func @unroll_vector_store (
452+ // CHECK-LABEL: func.func @vector_store_2D_to_4D (
442453// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
443- func.func @unroll_vector_store (%arg0: memref <4 x4 x4 x4 xf16 >, %arg1: vector <2 x2 xf16 >) {
444454 // CHECK: %[[C2:.*]] = arith.constant 2 : index
445455 // CHECK: %[[C1:.*]] = arith.constant 1 : index
446456 // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
447457 // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
448458 // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
449459 // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
450- %c1 = arith.constant 1 : index
451- vector.store %arg1 , %arg0 [%c1 , %c1 , %c1 , %c1 ] : memref <4 x4 x4 x4 xf16 >, vector <2 x2 xf16 >
452- return
453- }
0 commit comments