@@ -2377,7 +2377,7 @@ func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vecto
23772377
23782378// -----
23792379
2380- func.func @transfer_read_1d (%A : memref <?xf32 >, %base: index ) -> vector <17 xf32 > {
2380+ func.func @transfer_read_write_1d (%A : memref <?xf32 >, %base: index ) -> vector <17 xf32 > {
23812381 %f7 = arith.constant 7.0 : f32
23822382 %f = vector.transfer_read %A [%base ], %f7
23832383 {permutation_map = affine_map <(d0 ) -> (d0 )>} :
@@ -2387,7 +2387,7 @@ func.func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32>
23872387 vector <17 xf32 >, memref <?xf32 >
23882388 return %f: vector <17 xf32 >
23892389}
2390- // CHECK-LABEL: func @transfer_read_1d
2390+ // CHECK-LABEL: func @transfer_read_write_1d
23912391// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
23922392// CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
23932393// CHECK: %[[C7:.*]] = arith.constant 7.0
@@ -2449,9 +2449,77 @@ func.func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32>
24492449// CHECK-SAME: {alignment = 4 : i32} :
24502450// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
24512451
2452+ func.func @transfer_read_write_1d_scalable (%A : memref <?xf32 >, %base: index ) -> vector <[17 ]xf32 > {
2453+ %f7 = arith.constant 7.0 : f32
2454+ %f = vector.transfer_read %A [%base ], %f7
2455+ {permutation_map = affine_map <(d0 ) -> (d0 )>} :
2456+ memref <?xf32 >, vector <[17 ]xf32 >
2457+ vector.transfer_write %f , %A [%base ]
2458+ {permutation_map = affine_map <(d0 ) -> (d0 )>} :
2459+ vector <[17 ]xf32 >, memref <?xf32 >
2460+ return %f: vector <[17 ]xf32 >
2461+ }
2462+ // CHECK-LABEL: func @transfer_read_write_1d_scalable
2463+ // CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
2464+ // CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
2465+ // CHECK: %[[C7:.*]] = arith.constant 7.0
2466+ //
2467+ // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
2468+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
2469+ // CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
2470+ // CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
2471+ //
2472+ // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
2473+ // CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]xi32>
2474+ //
2475+ // 3. Create bound vector to compute in-bound mask:
2476+ // [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
2477+ // CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
2478+ // CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
2479+ // CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
2480+ // CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
2481+ // CHECK-SAME: : vector<[17]xi32>
2482+ //
2483+ // 4. Create pass-through vector.
2484+ // CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
2485+ //
2486+ // 5. Bitcast to vector form.
2487+ // CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
2488+ // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
2489+ //
2490+ // 6. Rewrite as a masked read.
2491+ // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[gep]], %[[mask]],
2492+ // CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
2493+ // CHECK-SAME: -> vector<[17]xf32>
2494+ //
2495+ // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
2496+ // CHECK: %[[C0_b:.*]] = arith.constant 0 : index
2497+ // CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
2498+ // CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
2499+ //
2500+ // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
2501+ // CHECK: %[[linearIndex_b:.*]] = llvm.intr.stepvector : vector<[17]xi32>
2502+ //
2503+ // 3. Create bound vector to compute in-bound mask:
2504+ // [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
2505+ // CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]] : index to i32
2506+ // CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
2507+ // CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
2508+ // CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
2509+ // CHECK-SAME: %[[boundVect_b]] : vector<[17]xi32>
2510+ //
2511+ // 4. Bitcast to vector form.
2512+ // CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
2513+ // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
2514+ //
2515+ // 5. Rewrite as a masked write.
2516+ // CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
2517+ // CHECK-SAME: {alignment = 4 : i32} :
2518+ // CHECK-SAME: vector<[17]xf32>, vector<[17]xi1> into !llvm.ptr
2519+
24522520// -----
24532521
2454- func.func @transfer_read_index_1d (%A : memref <?xindex >, %base: index ) -> vector <17 xindex > {
2522+ func.func @transfer_read_write_index_1d (%A : memref <?xindex >, %base: index ) -> vector <17 xindex > {
24552523 %f7 = arith.constant 7 : index
24562524 %f = vector.transfer_read %A [%base ], %f7
24572525 {permutation_map = affine_map <(d0 ) -> (d0 )>} :
@@ -2461,7 +2529,7 @@ func.func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<
24612529 vector <17 xindex >, memref <?xindex >
24622530 return %f: vector <17 xindex >
24632531}
2464- // CHECK-LABEL: func @transfer_read_index_1d
2532+ // CHECK-LABEL: func @transfer_read_write_index_1d
24652533// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
24662534// CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<17xindex>
24672535// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64>
@@ -2472,6 +2540,27 @@ func.func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<
24722540// CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} :
24732541// CHECK-SAME: vector<17xi64>, vector<17xi1> into !llvm.ptr
24742542
2543+ func.func @transfer_read_write_index_1d_scalable (%A : memref <?xindex >, %base: index ) -> vector <[17 ]xindex > {
2544+ %f7 = arith.constant 7 : index
2545+ %f = vector.transfer_read %A [%base ], %f7
2546+ {permutation_map = affine_map <(d0 ) -> (d0 )>} :
2547+ memref <?xindex >, vector <[17 ]xindex >
2548+ vector.transfer_write %f , %A [%base ]
2549+ {permutation_map = affine_map <(d0 ) -> (d0 )>} :
2550+ vector <[17 ]xindex >, memref <?xindex >
2551+ return %f: vector <[17 ]xindex >
2552+ }
2553+ // CHECK-LABEL: func @transfer_read_write_index_1d_scalable
2554+ // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xindex>
2555+ // CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<[17]xindex>
2556+ // CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<[17]xindex> to vector<[17]xi64>
2557+
2558+ // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
2559+ // CHECK-SAME: (!llvm.ptr, vector<[17]xi1>, vector<[17]xi64>) -> vector<[17]xi64>
2560+
2561+ // CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} :
2562+ // CHECK-SAME: vector<[17]xi64>, vector<[17]xi1> into !llvm.ptr
2563+
24752564// -----
24762565
24772566func.func @transfer_read_2d_to_1d (%A : memref <?x?xf32 >, %base0: index , %base1: index ) -> vector <17 xf32 > {
@@ -2501,9 +2590,34 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: i
25012590// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
25022591// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
25032592
2593+ func.func @transfer_read_2d_to_1d_scalable (%A : memref <?x?xf32 >, %base0: index , %base1: index ) -> vector <[17 ]xf32 > {
2594+ %f7 = arith.constant 7.0 : f32
2595+ %f = vector.transfer_read %A [%base0 , %base1 ], %f7
2596+ {permutation_map = affine_map <(d0 , d1 ) -> (d1 )>} :
2597+ memref <?x?xf32 >, vector <[17 ]xf32 >
2598+ return %f: vector <[17 ]xf32 >
2599+ }
2600+ // CHECK-LABEL: func @transfer_read_2d_to_1d_scalable
2601+ // CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
2602+ // CHECK: %[[c1:.*]] = arith.constant 1 : index
2603+ // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
2604+ //
2605+ // Compute the in-bound index (dim - offset)
2606+ // CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
2607+ //
2608+ // Create a vector with linear indices [ 0 .. vector_length - 1 ].
2609+ // CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]xi32>
2610+ //
2611+ // Create bound vector to compute in-bound mask:
2612+ // [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
2613+ // CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
2614+ // CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
2615+ // CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
2616+ // CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
2617+
25042618// -----
25052619
2506- func.func @transfer_read_1d_non_zero_addrspace (%A : memref <?xf32 , 3 >, %base: index ) -> vector <17 xf32 > {
2620+ func.func @transfer_read_write_1d_non_zero_addrspace (%A : memref <?xf32 , 3 >, %base: index ) -> vector <17 xf32 > {
25072621 %f7 = arith.constant 7.0 : f32
25082622 %f = vector.transfer_read %A [%base ], %f7
25092623 {permutation_map = affine_map <(d0 ) -> (d0 )>} :
@@ -2513,7 +2627,7 @@ func.func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: ind
25132627 vector <17 xf32 >, memref <?xf32 , 3 >
25142628 return %f: vector <17 xf32 >
25152629}
2516- // CHECK-LABEL: func @transfer_read_1d_non_zero_addrspace
2630+ // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
25172631// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
25182632//
25192633// 1. Check address space for GEP is correct.
@@ -2528,6 +2642,31 @@ func.func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: ind
25282642// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
25292643// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
25302644
2645+ func.func @transfer_read_write_1d_non_zero_addrspace_scalable (%A : memref <?xf32 , 3 >, %base: index ) -> vector <[17 ]xf32 > {
2646+ %f7 = arith.constant 7.0 : f32
2647+ %f = vector.transfer_read %A [%base ], %f7
2648+ {permutation_map = affine_map <(d0 ) -> (d0 )>} :
2649+ memref <?xf32 , 3 >, vector <[17 ]xf32 >
2650+ vector.transfer_write %f , %A [%base ]
2651+ {permutation_map = affine_map <(d0 ) -> (d0 )>} :
2652+ vector <[17 ]xf32 >, memref <?xf32 , 3 >
2653+ return %f: vector <[17 ]xf32 >
2654+ }
2655+ // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
2656+ // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
2657+ //
2658+ // 1. Check address space for GEP is correct.
2659+ // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
2660+ // CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
2661+ //
2662+ // 2. Check address space of the memref is correct.
2663+ // CHECK: %[[c0:.*]] = arith.constant 0 : index
2664+ // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
2665+ //
2666+ // 3. Check address space for GEP is correct.
2667+ // CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
2668+ // CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
2669+
25312670// -----
25322671
25332672func.func @transfer_read_1d_inbounds (%A : memref <?xf32 >, %base: index ) -> vector <17 xf32 > {
@@ -2546,51 +2685,71 @@ func.func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector
25462685// 2. Rewrite as a load.
25472686// CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<17xf32>
25482687
2688+ func.func @transfer_read_1d_inbounds_scalable (%A : memref <?xf32 >, %base: index ) -> vector <[17 ]xf32 > {
2689+ %f7 = arith.constant 7.0 : f32
2690+ %f = vector.transfer_read %A [%base ], %f7 {in_bounds = [true ]} :
2691+ memref <?xf32 >, vector <[17 ]xf32 >
2692+ return %f: vector <[17 ]xf32 >
2693+ }
2694+ // CHECK-LABEL: func @transfer_read_1d_inbounds_scalable
2695+ // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
2696+ //
2697+ // 1. Bitcast to vector form.
2698+ // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
2699+ // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
2700+ //
2701+ // 2. Rewrite as a load.
2702+ // CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<[17]xf32>
2703+
25492704// -----
25502705
2551- // CHECK-LABEL: func @transfer_read_1d_mask
2706+ // CHECK-LABEL: func @transfer_read_write_1d_mask
25522707// CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
25532708// CHECK: %[[cmpi:.*]] = arith.cmpi slt
25542709// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
25552710// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
2711+ // CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
2712+ // CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
2713+ // CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
25562714// CHECK: return %[[r]]
2557- func.func @transfer_read_1d_mask (%A : memref <?xf32 >, %base : index ) -> vector <5 xf32 > {
2715+ func.func @transfer_read_write_1d_mask (%A : memref <?xf32 >, %base : index ) -> vector <5 xf32 > {
25582716 %m = arith.constant dense <[0 , 0 , 1 , 0 , 1 ]> : vector <5 xi1 >
25592717 %f7 = arith.constant 7.0 : f32
25602718 %f = vector.transfer_read %A [%base ], %f7 , %m : memref <?xf32 >, vector <5 xf32 >
2719+ vector.transfer_write %f , %A [%base ], %m : vector <5 xf32 >, memref <?xf32 >
25612720 return %f: vector <5 xf32 >
25622721}
25632722
2564- // -----
2565-
2566- // CHECK-LABEL: func @transfer_read_1d_scalable_mask
2567- // CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
2568- // CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
2569- // CHECK: return %[[r]] : vector<[4]xf32>
2570- func.func @transfer_read_1d_scalable_mask (%arg0: memref <1 x?xf32 >, %mask: vector <[4 ]xi1 >) -> vector <[4 ]xf32 > {
2571- %c0 = arith.constant 0 : index
2572- %pad = arith.constant 0.0 : f32
2573- %vec = vector.transfer_read %arg0 [%c0 , %c0 ], %pad , %mask {in_bounds = [true ]} : memref <1 x?xf32 >, vector <[4 ]xf32 >
2574- return %vec : vector <[4 ]xf32 >
2723+ // CHECK-LABEL: func @transfer_read_write_1d_mask_scalable
2724+ // CHECK-SAME: %[[mask:[a-zA-Z0-9]*]]: vector<[5]xi1>
2725+ // CHECK: %[[cmpi:.*]] = arith.cmpi slt
2726+ // CHECK: %[[mask1:.*]] = arith.andi %[[cmpi]], %[[mask]]
2727+ // CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask1]]
2728+ // CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
2729+ // CHECK: %[[mask2:.*]] = arith.andi %[[cmpi_1]], %[[mask]]
2730+ // CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask2]]
2731+ // CHECK: return %[[r]]
2732+ func.func @transfer_read_write_1d_mask_scalable (%A : memref <?xf32 >, %base : index , %m : vector <[5 ]xi1 >) -> vector <[5 ]xf32 > {
2733+ %f7 = arith.constant 7.0 : f32
2734+ %f = vector.transfer_read %A [%base ], %f7 , %m : memref <?xf32 >, vector <[5 ]xf32 >
2735+ vector.transfer_write %f , %A [%base ], %m : vector <[5 ]xf32 >, memref <?xf32 >
2736+ return %f: vector <[5 ]xf32 >
25752737}
25762738
25772739// -----
2578- // CHECK-LABEL: func @transfer_write_1d_scalable_mask
2579- // CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr
2580- func.func @transfer_write_1d_scalable_mask (%arg0: memref <1 x?xf32 >, %vec: vector <[4 ]xf32 >, %mask: vector <[4 ]xi1 >) {
2581- %c0 = arith.constant 0 : index
2582- vector.transfer_write %vec , %arg0 [%c0 , %c0 ], %mask {in_bounds = [true ]} : vector <[4 ]xf32 >, memref <1 x?xf32 >
2583- return
2584- }
25852740
2586- // -----
2741+ // Can't lower xfer_read/xfer_write on tensors, but this shouldn't crash
25872742
2588- // CHECK-LABEL: func @transfer_write_tensor
2743+ // CHECK-LABEL: func @transfer_read_write_tensor
2744+ // CHECK: vector.transfer_read
25892745// CHECK: vector.transfer_write
2590- func.func @transfer_write_tensor (%arg0: vector <4 xf32 >,%arg1: tensor <?xf32 >) -> tensor <?xf32 > {
2591- %c0 = arith.constant 0 : index
2592- %0 = vector.transfer_write %arg0 , %arg1 [%c0 ] : vector <4 xf32 >, tensor <?xf32 >
2593- return %0 : tensor <?xf32 >
2746+ func.func @transfer_read_write_tensor (%A: tensor <?xf32 >, %base : index ) -> vector <4 xf32 > {
2747+ %f7 = arith.constant 7.0 : f32
2748+ %c0 = arith.constant 0 : index
2749+ %f = vector.transfer_read %A [%base ], %f7 : tensor <?xf32 >, vector <4 xf32 >
2750+ %w = vector.transfer_write %f , %A [%c0 ] : vector <4 xf32 >, tensor <?xf32 >
2751+ " test.some_use" (%w ) : (tensor <?xf32 >) -> ()
2752+ return %f : vector <4 xf32 >
25942753}
25952754
25962755// -----
0 commit comments