diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index eeedf68a1df7c..b3b8afdd8b4c1 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1729,18 +1729,18 @@ def Vector_LoadOp : Vector_Op<"load", [ "Value":$base, "ValueRange":$indices, CArg<"bool", "false">:$nontemporal, - CArg<"llvm::Align", "llvm::Align()">:$alignment), [{ + CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{ return build($_builder, $_state, resultType, base, indices, nontemporal, - alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) : + alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) : nullptr); }]>, OpBuilder<(ins "TypeRange":$resultTypes, "Value":$base, "ValueRange":$indices, CArg<"bool", "false">:$nontemporal, - CArg<"llvm::Align", "llvm::Align()">:$alignment), [{ + CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{ return build($_builder, $_state, resultTypes, base, indices, nontemporal, - alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) : + alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) : nullptr); }]> ]; @@ -1847,9 +1847,9 @@ def Vector_StoreOp : Vector_Op<"store", [ "Value":$base, "ValueRange":$indices, CArg<"bool", "false">:$nontemporal, - CArg<"llvm::Align", "llvm::Align()">:$alignment), [{ + CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{ return build($_builder, $_state, valueToStore, base, indices, nontemporal, - alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) : + alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) : nullptr); }]> ]; @@ -1876,7 +1876,9 @@ def Vector_MaskedLoadOp : Arguments<(ins Arg:$base, Variadic:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, - AnyVectorOfNonZeroRank:$pass_thru)>, + AnyVectorOfNonZeroRank:$pass_thru, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>, Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "loads elements from memory into a vector as defined by a mask vector"; @@ -1912,6 +1914,12 @@ def Vector_MaskedLoadOp : %1 = vector.maskedload %base[%i, %j], %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` + + An optional `alignment` attribute allows to specify the byte alignment of the + load operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { @@ -1932,6 +1940,29 @@ def Vector_MaskedLoadOp : let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "VectorType":$resultType, + "Value":$base, + "ValueRange":$indices, + "Value":$mask, + "Value":$passthrough, + CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{ + return build($_builder, $_state, resultType, base, indices, mask, passthrough, + alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) : + nullptr); + }]>, + OpBuilder<(ins "TypeRange":$resultTypes, + "Value":$base, + "ValueRange":$indices, + "Value":$mask, + "Value":$passthrough, + CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{ + return build($_builder, $_state, resultTypes, base, indices, mask, passthrough, + alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) : + nullptr); + }]> + ]; } def Vector_MaskedStoreOp : @@ -1939,7 +1970,9 @@ def Vector_MaskedStoreOp : Arguments<(ins Arg:$base, Variadic:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, - AnyVectorOfNonZeroRank:$valueToStore)> { + AnyVectorOfNonZeroRank:$valueToStore, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> { let summary = "stores elements from a vector into memory as defined by a mask vector"; @@ -1974,6 +2007,12 @@ def Vector_MaskedStoreOp : vector.maskedstore %base[%i, %j], %mask, %value : memref, vector<16xi1>, vector<16xf32> ``` + + An optional `alignment` attribute allows to specify the byte alignment of the + store operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { @@ -1992,6 +2031,18 @@ def Vector_MaskedStoreOp : let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "Value":$base, + "ValueRange":$indices, + "Value":$mask, + "Value":$valueToStore, + CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{ + return build($_builder, $_state, base, indices, mask, valueToStore, + alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) : + nullptr); + }]> + ]; } def Vector_GatherOp : diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 17a79e3815b97..f9e2a01dbf969 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -247,8 +247,9 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern { MemRefType memRefTy = loadOrStoreOp.getMemRefType(); // Resolve alignment. - unsigned align; - if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy, + unsigned align = loadOrStoreOp.getAlignment().value_or(0); + if (!align && + failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy, memRefTy, align, useVectorAlignment))) return rewriter.notifyMatchFailure(loadOrStoreOp, "could not resolve alignment"); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 31e17fb3e3cc6..5a424a8ac0d5f 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -1679,6 +1679,16 @@ func.func @load_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vec // ----- +func.func @load_with_alignment(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> { + %0 = vector.load %memref[%i, %j] { alignment = 8 } : memref<200x100xf32>, vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @load_with_alignment +// CHECK: llvm.load {{.*}} {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32> + +// ----- + //===----------------------------------------------------------------------===// // vector.store //===----------------------------------------------------------------------===// @@ -1785,6 +1795,16 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) { // ----- +func.func @store_with_alignment(%memref : memref<200x100xf32>, %i : index, %j : index, %val : vector<4xf32>) { + vector.store %val, %memref[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<4xf32> + return +} + +// CHECK-LABEL: func @store_with_alignment +// CHECK: llvm.store %{{.*}} {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr + +// ----- + //===----------------------------------------------------------------------===// // vector.maskedload //===----------------------------------------------------------------------===// @@ -1839,6 +1859,16 @@ func.func @masked_load_index_scalable(%arg0: memref, %arg1: vector<[16] // ----- +func.func @masked_load_with_alignment(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>, %arg3: index) -> vector<16xf32> { + %0 = vector.maskedload %arg0[%arg3], %arg1, %arg2 { alignment = 2 } : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @masked_load_with_alignment +// CHECK: llvm.intr.masked.load %{{.*}} {alignment = 2 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32> + +// ----- + //===----------------------------------------------------------------------===// // vector.maskedstore //===----------------------------------------------------------------------===// @@ -1891,6 +1921,16 @@ func.func @masked_store_index_scalable(%arg0: memref, %arg1: vector<[16 // ----- +func.func @masked_store_with_alignment(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>, %arg3: index) { + vector.maskedstore %arg0[%arg3], %arg1, %arg2 { alignment = 2 } : memref, vector<16xi1>, vector<16xf32> + return +} + +// CHECK-LABEL: func @masked_store_with_alignment +// CHECK: llvm.intr.masked.store %{{.*}} {alignment = 2 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr + +// ----- + //===----------------------------------------------------------------------===// // vector.gather //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index c21de562d05e1..211e16db85a94 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1305,6 +1305,26 @@ func.func @store_memref_index_mismatch(%base : memref, %value : vector<16 // ----- +//===----------------------------------------------------------------------===// +// vector.maskedload +//===----------------------------------------------------------------------===// + +func.func @maskedload_negative_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) { + // expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %val = vector.maskedload %base[%index], %mask, %pass { alignment = -1 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32> + return +} + +// ----- + +func.func @maskedload_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) { + // expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %val = vector.maskedload %base[%index], %mask, %pass { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32> + return +} + +// ----- + func.func @maskedload_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xf32>) { %c0 = arith.constant 0 : index // expected-error@+1 {{'vector.maskedload' op base and result element type should match}} @@ -1336,6 +1356,26 @@ func.func @maskedload_memref_mismatch(%base: memref, %mask: vector<16xi1> // ----- +//===----------------------------------------------------------------------===// +// vector.maskedstore +//===----------------------------------------------------------------------===// + +func.func @maskedstore_negative_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) { + // expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + vector.maskedstore %base[%index], %mask, %value { alignment = -1 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32> + return +} + +// ----- + +func.func @maskedstore_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) { + // expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + vector.maskedstore %base[%index], %mask, %value { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32> + return +} + +// ----- + func.func @maskedstore_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { %c0 = arith.constant 0 : index // expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}} @@ -1912,8 +1952,7 @@ func.func @vector_load(%src : memref) { // ----- -func.func @invalid_load_alignment(%memref: memref<4xi32>) { - %c0 = arith.constant 0 : index +func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) { // expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32> return @@ -1921,6 +1960,14 @@ func.func @invalid_load_alignment(%memref: memref<4xi32>) { // ----- +func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) { + // expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %val = vector.load %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32> + return +} + +// ----- + //===----------------------------------------------------------------------===// // vector.store //===----------------------------------------------------------------------===// @@ -1934,8 +1981,15 @@ func.func @vector_store(%dest : memref, %vec : vector<16x16xi8>) { // ----- -func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) { - %c0 = arith.constant 0 : index +func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) { + // expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32> + return +} + +// ----- + +func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) { // expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32> return