diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4623b9667998c..64a9ad8e9bade 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -61,8 +61,8 @@ struct ConvertVectorToLLVMPass
 } // namespace
 
 void ConvertVectorToLLVMPass::runOnOperation() {
-  // Perform progressive lowering of operations on slices and
-  // all contraction operations. Also applies folding and DCE.
+  // Perform progressive lowering of operations on slices and all contraction
+  // operations. Also materializes masks, applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -76,6 +76,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
                                            VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
+    populateVectorMaskMaterializationPatterns(patterns,
+                                              force32BitVectorIndices);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
@@ -83,7 +85,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   LowerToLLVMOptions options(&getContext());
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());
-  populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 82351eb7c98a4..91e5358622b69 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -7,7 +7,7 @@
 // CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
 // CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi32>
 // CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi32>
-// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi32>
 // CMP32: return %[[T4]] : vector<11xi1>
 
 // CMP64-LABEL: @genbool_var_1d(
@@ -16,7 +16,7 @@
 // CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
 // CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
 // CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
-// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
 // CMP64: return %[[T4]] : vector<11xi1>
 
 func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2473fe933ffcb..ea88fece9e662 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -3097,7 +3097,7 @@ func.func @create_mask_0d(%num_elems : index) -> vector<i1> {
 // CHECK: %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
 // CHECK: %[[BOUNDS:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
 // CHECK: %[[BOUNDS_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOUNDS]] : vector<1xi32> to vector<i32>
-// CHECK: %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS_CAST]] : vector<i32>
+// CHECK: %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS_CAST]], %[[INDICES]] : vector<i32>
 // CHECK: return %[[RESULT]] : vector<i1>
 
 // -----
@@ -3113,7 +3113,7 @@ func.func @create_mask_1d(%num_elems : index) -> vector<4xi1> {
 // CHECK: %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
 // CHECK: %[[BOUNDS_INSERT:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
 // CHECK: %[[BOUNDS:.*]] = llvm.shufflevector %[[BOUNDS_INSERT]]
-// CHECK: %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS]] : vector<4xi32>
+// CHECK: %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS]], %[[INDICES]] : vector<4xi32>
 // CHECK: return %[[RESULT]] : vector<4xi1>
 
 // -----
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
index 8f01cc2b8d44c..d3f6d7eca90b4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
@@ -14,30 +14,28 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 // CHECK-LABEL: func @transfer_read_write_1d
 // CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
 // CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
-// CHECK: %[[C7:.*]] = arith.constant 7.0
-//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
-// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
+// 1. Create pass-through vector.
+// CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<17xf32>
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex:.*]] = arith.constant dense
+// CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense
 // CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
 //
-// 3. Create bound vector to compute in-bound mask:
+// 3. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
+// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
+//
+// 4. Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 // CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] :
 // CMP32-SAME: index to i32
 // CMP64-SAME: index to i64
 // CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
 // CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] : vector<17x[[$IDX_TYPE]]>
+// CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
 // CMP64-SAME: : vector<17xi64>
 //
-// 4. Create pass-through vector.
-// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
-//
 // 5. Bitcast to vector form.
 // CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -48,28 +46,23 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 // CHECK-SAME: -> vector<17xf32>
 //
 // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 // CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex_b:.*]] = arith.constant dense
-// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
-//
-// 3. Create bound vector to compute in-bound mask:
+// 2. Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 // CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]]
 // CMP32-SAME: index to i32
 // CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
 // CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
-// CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
-// CHECK-SAME: %[[boundVect_b]] : vector<17x[[$IDX_TYPE]]>
+// CHECK: %[[mask_b:.*]] = arith.cmpi sgt, %[[boundVect_b]],
+// CHECK-SAME: %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
 //
-// 4. Bitcast to vector form.
+// 3. Bitcast to vector form.
 // CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
 // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
 //
-// 5. Rewrite as a masked write.
+// 4. Rewrite as a masked write.
 // CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
 // CHECK-SAME: {alignment = 4 : i32} :
 // CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
@@ -87,17 +80,18 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
 // CHECK-LABEL: func @transfer_read_write_1d_scalable
 // CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
 // CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
-// CHECK: %[[C7:.*]] = arith.constant 7.0
+// 1. Create pass-through vector.
+// CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<[17]xf32>
 //
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// 2. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 // CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 3. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 // CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]>
 //
-// 3. Create bound vector to compute in-bound mask:
+// 4. Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 // CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
 // CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
@@ -105,9 +99,6 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
 // CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
 // CHECK-SAME: : vector<[17]x[[$IDX_TYPE]]>
 //
-// 4. Create pass-through vector.
-// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
-//
 // 5. Bitcast to vector form.
 // CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -118,8 +109,7 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
 // CHECK-SAME: -> vector<[17]xf32>
 //
 // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 // CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
@@ -197,23 +187,23 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
 }
 // CHECK-LABEL: func @transfer_read_2d_to_1d
 // CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
+//
+// Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK-SAME: vector<17x[[$IDX_TYPE]]>
+//
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
 //
 // Compute the in-bound index (dim - offset)
 // CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
 //
-// Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex:.*]] = arith.constant dense
-// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
-// CHECK-SAME: vector<17x[[$IDX_TYPE]]>
-//
 // Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 // CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
 // CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
 // CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+// CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]]
 
 func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<[17]xf32> {
   %f7 = arith.constant 7.0: f32
@@ -255,12 +245,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -> vector<17xf32> {
 // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
 // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
 //
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
 // 1. Check address space for GEP is correct.
 // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 // CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
 //
 // 2. Check address space of the memref is correct.
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
 //
 // 3. Check address space for GEP is correct.
@@ -280,12 +271,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref<?xf32, 3>, %base: index) -> vector<[17]xf32> {
 // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
 // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
 //
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
 // 1. Check address space for GEP is correct.
 // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 // CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
 //
 // 2. Check address space of the memref is correct.
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
 //
 // 3. Check address space for GEP is correct.
@@ -330,10 +322,10 @@ func.func @transfer_read_1d_inbounds_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
 
 // CHECK-LABEL: func @transfer_read_write_1d_mask
 // CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
-// CHECK: %[[cmpi:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi:.*]] = arith.cmpi sgt
 // CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
 // CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
-// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi_1:.*]] = arith.cmpi sgt
 // CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
 // CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
 // CHECK: return %[[r]]
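
Note on the expected-output changes: moving populateVectorMaskMaterializationPatterns
into the first phase means the materialized mask ops now go through
applyPatternsAndFoldGreedily, so arith canonicalizations run on them. When one
operand of an arith.cmpi is a constant, the canonicalizer moves the constant to
the right-hand side and swaps the predicate; the fixed-length masks compare
against a constant linear-index vector, so their slt flips to sgt, while the
scalable masks compare against llvm.intr.stepvector (not a constant) and keep
slt. A minimal MLIR sketch of the two cases; @scalable_mask_sketch, %bound, and
the vector<4x...> shapes are illustrative names, not taken from the patch:

  // Fixed-length: the linear-index vector is a constant, so the greedy
  // canonicalization moves it to the RHS and swaps the predicate:
  //   %mask = arith.cmpi slt, %cst_indices, %bound : vector<4xi32>
  //     ==>
  //   %mask = arith.cmpi sgt, %bound, %cst_indices : vector<4xi32>

  // Scalable: the indices come from llvm.intr.stepvector, which is not a
  // constant, so the slt form survives (matching the scalable CHECKs above).
  func.func @scalable_mask_sketch(%bound : vector<[4]xi32>) -> vector<[4]xi1> {
    %step = llvm.intr.stepvector : vector<[4]xi32>
    %mask = arith.cmpi slt, %step, %bound : vector<[4]xi32>
    return %mask : vector<[4]xi1>
  }

The greedy driver also folds and CSEs constants, which is why the duplicated
%[[C0_b]]/%[[linearIndex_b]] captures disappear in favor of reusing %[[C0]] and
%[[linearIndex]], and why several checks become CHECK-DAG: constants are hoisted
to the top of the function and their relative order is no longer guaranteed.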