diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 76b08e664ee76..5e2461b4508e4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -87,21 +87,21 @@ class LLVM_TernarySameArgsIntrOpF<string func, list<Trait> traits = []> :
 class LLVM_CountZerosIntrOp<string func, list<Trait> traits = []> :
     LLVM_OneResultIntrOp<func, [], [0], traits,
-                     /*requiresFastmath=*/0,
+                     /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
                      /*immArgPositions=*/[1],
                      /*immArgAttrNames=*/["is_zero_poison"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
                    I1Attr:$is_zero_poison);
 }
 
 def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure],
-                     /*requiresFastmath=*/0,
+                     /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
                      /*immArgPositions=*/[1],
                      /*immArgAttrNames=*/["is_int_min_poison"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
                    I1Attr:$is_int_min_poison);
 }
 
 def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure],
-                     /*requiresFastmath=*/0,
+                     /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
                      /*immArgPositions=*/[1], /*immArgAttrNames=*/["bit"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$in, I32Attr:$bit);
 }
@@ -360,8 +360,8 @@ def LLVM_LifetimeEndOp : LLVM_LifetimeBaseOp<"lifetime.end">;
 def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1],
     [DeclareOpInterfaceMethods<PromotableOpInterface>],
-    /*requiresFastmath=*/0, /*immArgPositions=*/[0],
-    /*immArgAttrNames=*/["size"]> {
+    /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
+    /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
   let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
   let results = (outs LLVM_DefaultPointer:$res);
   let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
@@ -412,6 +412,7 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
                                                        true : []),
                            /*requiresFastmath=*/0,
+                           /*requiresArgAndResultAttrs=*/0,
                            /*immArgPositions=*/[], /*immArgAttrNames=*/[]> {
   dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs),
                          !foreach(i, !range(numArgs), "arg_" #i));
@@ -589,7 +590,7 @@ def LLVM_ExpectOp
 def LLVM_ExpectWithProbabilityOp
     : LLVM_OneResultIntrOp<"expect.with.probability", [], [0],
           [Pure, AllTypesMatch<["val", "expected", "res"]>],
-          /*requiresFastmath=*/0,
+          /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
           /*immArgPositions=*/[2], /*immArgAttrNames=*/["prob"]> {
   let arguments = (ins AnySignlessInteger:$val,
                    AnySignlessInteger:$expected,
@@ -825,7 +826,7 @@ class LLVM_VecReductionAccBase<string mnem, Type element>
        /*overloadedResults=*/[],
        /*overloadedOperands=*/[1],
        /*traits=*/[Pure, SameOperandsAndResultElementType],
-       /*equiresFastmath=*/1>,
+       /*requiresFastmath=*/1>,
     Arguments<(ins element:$start_value,
                LLVM_VectorOf<element>:$input,
                DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags)>;
@@ -1069,14 +1070,36 @@ def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> {
 }
 
 /// Create a call to Masked Expand Load intrinsic.
-def LLVM_masked_expandload : LLVM_IntrOp<"masked.expandload", [0], [], [], 1> {
-  let arguments = (ins LLVM_AnyPointer, LLVM_VectorOf<I1>, LLVM_AnyVector);
+def LLVM_masked_expandload
+    : LLVM_OneResultIntrOp<"masked.expandload", [0], [],
+          /*traits=*/[], /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/1,
+          /*immArgPositions=*/[], /*immArgAttrNames=*/[]> {
+  dag args = (ins LLVM_AnyPointer:$ptr,
+                  LLVM_VectorOf<I1>:$mask,
+                  LLVM_AnyVector:$passthru);
+
+  let arguments = !con(args, baseArgs);
+
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resTy, "Value":$ptr, "Value":$mask,
+                   "Value":$passthru, CArg<"uint64_t", "1">:$align)>
+  ];
 }
 
 /// Create a call to Masked Compress Store intrinsic.
 def LLVM_masked_compressstore
-    : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0> {
-  let arguments = (ins LLVM_AnyVector, LLVM_AnyPointer, LLVM_VectorOf<I1>);
+    : LLVM_ZeroResultIntrOp<"masked.compressstore", [0],
+          /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
+          /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+          /*immArgPositions=*/[], /*immArgAttrNames=*/[]> {
+  dag args = (ins LLVM_AnyVector:$value,
+                  LLVM_AnyPointer:$ptr,
+                  LLVM_VectorOf<I1>:$mask);
+
+  let arguments = !con(args, baseArgs);
+
+  let builders = [
+    OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
+                   CArg<"uint64_t", "1">:$align)>
+  ];
 }
 
 //
@@ -1155,7 +1178,7 @@ def LLVM_vector_insert
       PredOpTrait<"it is not inserting scalable into fixed-length vectors.",
                   CPred<"!isScalableVectorType($srcvec.getType()) || "
                         "isScalableVectorType($dstvec.getType())">>],
-      /*requiresFastmath=*/0,
+      /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
       /*immArgPositions=*/[2], /*immArgAttrNames=*/["pos"]> {
   let arguments = (ins LLVM_AnyVector:$dstvec, LLVM_AnyVector:$srcvec,
                    I64Attr:$pos);
@@ -1189,7 +1212,7 @@ def LLVM_vector_extract
       PredOpTrait<"it is not extracting scalable from fixed-length vectors.",
                   CPred<"!isScalableVectorType($res.getType()) || "
                         "isScalableVectorType($srcvec.getType())">>],
-      /*requiresFastmath=*/0,
+      /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
       /*immArgPositions=*/[1], /*immArgAttrNames=*/["pos"]> {
   let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos);
   let results = (outs LLVM_AnyVector:$res);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a8d7cf2069547..d6aa9580870a8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -475,11 +475,12 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
                            list<int> overloadedOperands = [],
                            list<Trait> traits = [],
                            bit requiresFastmath = 0,
+                           bit requiresArgAndResultAttrs = 0,
                            list<int> immArgPositions = [],
                            list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
                   /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
-                  /*requiresFastmath=*/requiresFastmath,
+                  /*requiresFastmath=*/requiresFastmath,
+                  /*requiresArgAndResultAttrs=*/requiresArgAndResultAttrs,
                   /*requiresOpBundles=*/0, immArgPositions,
                   immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 30c1d97ba58f1..26c7e666c36a4 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2054,7 +2054,9 @@ def Vector_GatherOp :
                  Variadic<Index>:$indices,
                  VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                  VectorOfNonZeroRankOf<[I1]>:$mask,
-                 AnyVectorOfNonZeroRank:$pass_thru)>,
+                 AnyVectorOfNonZeroRank:$pass_thru,
+                 ConfinedAttr<OptionalAttr<I64Attr>,
+                     [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
     Results<(outs AnyVectorOfNonZeroRank:$result)> {
 
   let summary = [{
@@ -2085,6 +2087,12 @@ def Vector_GatherOp :
     during progressively lowering to bring other memory operations closer to
     hardware ISA support for a gather.
 
+    An optional `alignment` attribute allows specifying the byte alignment of
+    the gather operation. It must be a positive power of 2. The operation must
+    access memory at an address aligned to this boundary. Violations may lead
+    to architecture-specific faults or performance penalties.
+    Omitting the attribute indicates no specific alignment requirement.
+
     Examples:
 
     ```mlir
@@ -2111,6 +2119,20 @@ def Vector_GatherOp :
       "`into` type($result)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$index_vec,
+                   "Value":$mask,
+                   "Value":$passthrough,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, index_vec,
+                   mask, passthrough,
+                   alignment.has_value()
+                       ? $_builder.getI64IntegerAttr(alignment->value())
+                       : nullptr);
+    }]>
+  ];
 }
 
 def Vector_ScatterOp :
@@ -2119,7 +2141,9 @@ def Vector_ScatterOp :
                  Variadic<Index>:$indices,
                  VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                  VectorOfNonZeroRankOf<[I1]>:$mask,
-                 AnyVectorOfNonZeroRank:$valueToStore)> {
+                 AnyVectorOfNonZeroRank:$valueToStore,
+                 ConfinedAttr<OptionalAttr<I64Attr>,
+                     [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
 
   let summary = [{
     scatters elements from a vector into memory as defined by an index vector
@@ -2153,6 +2177,12 @@ def Vector_ScatterOp :
     correspond to those of the `llvm.masked.scatter`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
 
+    An optional `alignment` attribute allows specifying the byte alignment of
+    the scatter operation. It must be a positive power of 2. The operation must
+    access memory at an address aligned to this boundary. Violations may lead
+    to architecture-specific faults or performance penalties.
+    Omitting the attribute indicates no specific alignment requirement.
+
     Examples:
 
     ```mlir
@@ -2177,6 +2207,19 @@ def Vector_ScatterOp :
       "type($index_vec) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
+
+  let builders = [
+    OpBuilder<(ins "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$index_vec,
+                   "Value":$mask,
+                   "Value":$valueToStore,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, base, indices, index_vec, mask,
+                   valueToStore,
+                   alignment.has_value()
+                       ? $_builder.getI64IntegerAttr(alignment->value())
+                       : nullptr);
+    }]>
+  ];
 }
 
 def Vector_ExpandLoadOp :
@@ -2184,7 +2227,9 @@ def Vector_ExpandLoadOp :
   Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                  Variadic<Index>:$indices,
                  FixedVectorOfNonZeroRankOf<[I1]>:$mask,
-                 AnyVectorOfNonZeroRank:$pass_thru)>,
+                 AnyVectorOfNonZeroRank:$pass_thru,
+                 ConfinedAttr<OptionalAttr<I64Attr>,
+                     [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
   Results<(outs AnyVectorOfNonZeroRank:$result)> {
 
   let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2216,6 +2261,12 @@ def Vector_ExpandLoadOp :
     correspond to those of the `llvm.masked.expandload`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
 
+    An optional `alignment` attribute allows specifying the byte alignment of
+    the load operation. It must be a positive power of 2. The operation must
+    access memory at an address aligned to this boundary. Violations may lead
+    to architecture-specific faults or performance penalties.
+    Omitting the attribute indicates no specific alignment requirement.
+
     Note, at the moment this Op is only available for fixed-width vectors.
 
     Examples:
 
@@ -2246,6 +2297,19 @@ def Vector_ExpandLoadOp :
       "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$mask,
+                   "Value":$passthrough,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, mask,
+                   passthrough,
+                   alignment.has_value()
+                       ? $_builder.getI64IntegerAttr(alignment->value())
+                       : nullptr);
+    }]>
+  ];
 }
 
 def Vector_CompressStoreOp :
@@ -2253,7 +2317,9 @@ def Vector_CompressStoreOp :
   Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                  Variadic<Index>:$indices,
                  FixedVectorOfNonZeroRankOf<[I1]>:$mask,
-                 AnyVectorOfNonZeroRank:$valueToStore)> {
+                 AnyVectorOfNonZeroRank:$valueToStore,
+                 ConfinedAttr<OptionalAttr<I64Attr>,
+                     [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
 
   let summary = "writes elements selectively from a vector as defined by a mask";
 
@@ -2284,6 +2350,12 @@ def Vector_CompressStoreOp :
     correspond to those of the `llvm.masked.compressstore`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
 
+    An optional `alignment` attribute allows specifying the byte alignment of
+    the store operation. It must be a positive power of 2. The operation must
+    access memory at an address aligned to this boundary. Violations may lead
+    to architecture-specific faults or performance penalties.
+    Omitting the attribute indicates no specific alignment requirement.
+
     Note, at the moment this Op is only available for fixed-width vectors.
 
     Examples:
 
@@ -2312,6 +2384,17 @@ def Vector_CompressStoreOp :
       "type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$mask,
+                   "Value":$valueToStore,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, base, indices, mask, valueToStore,
+                   alignment.has_value()
+                       ? $_builder.getI64IntegerAttr(alignment->value())
+                       : nullptr);
+    }]>
+  ];
 }
 
 def Vector_ShapeCastOp :
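The four new `builders` above share one pattern: an `llvm::MaybeAlign` convenience parameter that is materialized as the `alignment` attribute only when it is set. A minimal sketch of calling the gather builder from C++ — the helper name and the SSA values are illustrative assumptions, not part of the patch:

```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/Support/Alignment.h"

using namespace mlir;

// Sketch only: `b`, `loc`, and the operand values are assumed to exist in
// the caller with compatible types.
static Value buildAlignedGather(OpBuilder &b, Location loc, VectorType resTy,
                                Value base, ValueRange indices, Value indexVec,
                                Value mask, Value passThru) {
  // MaybeAlign(8) becomes `alignment = 8` on the op; a default-constructed
  // MaybeAlign() would leave the optional attribute unset.
  auto gather = b.create<vector::GatherOp>(loc, resTy, base, indices, indexVec,
                                           mask, passThru, llvm::MaybeAlign(8));
  return gather.getResult();
}
```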
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f9e2a01dbf969..8117c0597f265 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -299,8 +299,9 @@ class VectorGatherOpConversion
     }
 
     // Resolve alignment.
-    unsigned align;
-    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+    unsigned align = gather.getAlignment().value_or(0);
+    if (!align &&
+        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                         memRefType, align, useVectorAlignment)))
       return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
 
@@ -354,8 +355,9 @@ class VectorScatterOpConversion
     }
 
     // Resolve alignment.
-    unsigned align;
-    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+    unsigned align = scatter.getAlignment().value_or(0);
+    if (!align &&
+        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                         memRefType, align, useVectorAlignment)))
       return rewriter.notifyMatchFailure(scatter,
                                          "could not resolve alignment");
@@ -399,8 +401,14 @@ class VectorExpandLoadOpConversion
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                      adaptor.getBase(), adaptor.getIndices());
 
+    // From:
+    // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+    // The pointer alignment defaults to 1.
+    uint64_t alignment = expand.getAlignment().value_or(1);
+
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
-        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
+        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
+        alignment);
     return success();
   }
 };
@@ -421,8 +429,13 @@ class VectorCompressStoreOpConversion
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                      adaptor.getBase(), adaptor.getIndices());
 
+    // From:
+    // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+    // The pointer alignment defaults to 1.
+    uint64_t alignment = compress.getAlignment().value_or(1);
+
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
-        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
+        compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 422039f81855a..c7a193c31a59f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
   return success();
 }
 
+static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder,
+                                                    bool isExpandLoad,
+                                                    uint64_t alignment = 1) {
+  // From
+  // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+  // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+  //
+  // The pointer alignment defaults to 1.
+  if (alignment == 1) {
+    return nullptr;
+  }
+
+  auto emptyDictAttr = builder.getDictionaryAttr({});
+  auto alignmentAttr = builder.getI64IntegerAttr(alignment);
+  auto namedAttr =
+      builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
+  SmallVector<NamedAttribute> attrs = {namedAttr};
+  auto alignDictAttr = builder.getDictionaryAttr(attrs);
+  // From
+  // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+  // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+  //
+  // The align parameter attribute can be provided for [expandload]'s first
+  // argument. The align parameter attribute can be provided for
+  // [compressstore]'s second argument.
+  int pos = isExpandLoad ? 0 : 1;
+  return pos == 0 ? builder.getArrayAttr(
+                        {alignDictAttr, emptyDictAttr, emptyDictAttr})
+                  : builder.getArrayAttr(
+                        {emptyDictAttr, alignDictAttr, emptyDictAttr});
+}
+
 //===----------------------------------------------------------------------===//
 // Operand bundle helpers.
 //===----------------------------------------------------------------------===//
@@ -4116,6 +4148,33 @@ LogicalResult LLVM::masked_scatter::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// masked_expandload (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
+                                    mlir::TypeRange resTys, Value ptr,
+                                    Value mask, Value passthru,
+                                    uint64_t align) {
+  ArrayAttr callArgs = getLLVMAlignParamForCompressExpand(builder, true, align);
+  build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/callArgs,
+        /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// masked_compressstore (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_compressstore::build(OpBuilder &builder,
+                                       OperationState &state, Value value,
+                                       Value ptr, Value mask, uint64_t align) {
+  ArrayAttr callArgs =
+      getLLVMAlignParamForCompressExpand(builder, false, align);
+  build(builder, state, value, ptr, mask, /*arg_attrs=*/callArgs,
+        /*res_attrs=*/nullptr);
+}
+
 //===----------------------------------------------------------------------===//
 // InlineAsmOp
 //===----------------------------------------------------------------------===//
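The two `build` overloads above funnel the alignment through `getLLVMAlignParamForCompressExpand`, which returns `nullptr` for the LangRef default alignment of 1 and otherwise a three-entry `arg_attrs` array with `llvm.align` on the pointer operand (index 0 for expandload, index 1 for compressstore). A hedged usage sketch — the function name and operand values are assumptions for illustration:

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

using namespace mlir;

// Sketch only: `b`, `loc`, `vecTy`, `ptr`, `mask`, and `passthru` are
// assumed to be in scope with compatible types.
static void emitAlignedExpandCompress(OpBuilder &b, Location loc, Type vecTy,
                                      Value ptr, Value mask, Value passthru) {
  // align == 8 attaches arg_attrs = [{llvm.align = 8 : i64}, {}, {}].
  auto load = b.create<LLVM::masked_expandload>(loc, TypeRange{vecTy}, ptr,
                                                mask, passthru, /*align=*/8);
  // align == 1 (the LangRef default) emits no arg_attrs at all.
  b.create<LLVM::masked_compressstore>(loc, load.getResult(), ptr, mask,
                                       /*align=*/1);
}
```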
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 5a424a8ac0d5f..41ee4230ee913 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2042,6 +2042,16 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
 
 // -----
 
+func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) -> vector<3xf32> {
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_with_alignment
+// CHECK: llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.scatter
 //===----------------------------------------------------------------------===//
@@ -2118,6 +2128,17 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
 // CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>
 
+// -----
+
+func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) {
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 { alignment = 8 } : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_alignment
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -2149,6 +2170,15 @@ func.func @expand_load_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %a
 
 // -----
 
+func.func @expand_load_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) -> vector<11xindex> {
+  %0 = vector.expandload %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex> into vector<11xindex>
+  return %0 : vector<11xindex>
+}
+// CHECK-LABEL: func @expand_load_op_with_alignment
+// CHECK: %{{.*}} = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<11xi1>, vector<11xi64>) -> vector<11xi64>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.compressstore
 //===----------------------------------------------------------------------===//
@@ -2177,6 +2207,15 @@ func.func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>,
 
 // -----
 
+func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) {
+  vector.compressstore %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex>
+  return
+}
+// CHECK-LABEL: func @compress_store_op_with_alignment
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<11xi64>, !llvm.ptr, vector<11xi1>) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.splat
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 211e16db85a94..2e72bf036fa71 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1317,7 +1317,7 @@ func.func @maskedload_negative_alignment(%base: memref<4xi32>, %mask: vector<32x
 
 // -----
 
-func.func @maskedload_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
+func.func @maskedload_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
   // expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
   %val = vector.maskedload %base[%index], %mask, %pass { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
   return
@@ -1368,7 +1368,7 @@ func.func @maskedstore_negative_alignment(%base: memref<4xi32>, %mask: vector<32
 
 // -----
 
-func.func @maskedstore_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
+func.func @maskedstore_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
   // expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
   vector.maskedstore %base[%index], %mask, %value { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
   return
@@ -1470,6 +1470,24 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
 
 // -----
 
+func.func @gather_negative_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
+                                     %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+  // expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    { alignment = -1 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
+                                             %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+  // expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    { alignment = 3 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
                              %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
@@ -1531,6 +1549,24 @@ func.func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi
 
 // -----
 
+func.func @scatter_negative_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
+                                      %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+  // expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.scatter %base[%c0][%indices], %mask, %value { alignment = -1 }
+    : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
+                                            %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+  // expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.scatter %base[%c0][%indices], %mask, %value { alignment = 3 }
+    : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error@+1 {{'vector.expandload' op base and result element type should match}}
@@ -1571,6 +1607,20 @@ func.func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>,
 
 // -----
 
+func.func @expand_negative_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
+  // expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @expand_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
+  // expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}
@@ -1603,6 +1653,20 @@ func.func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>
 
 // -----
 
+func.func @compress_negative_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+  // expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.compressstore %base[%c0], %mask, %value { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func.func @compress_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+  // expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.compressstore %base[%c0], %mask, %value { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func.func @scan_reduction_dim_constraint(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
   // expected-error@+1 {{'vector.scan' op reduction dimension 5 has to be less than 2}}
   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 5} :
@@ -1952,7 +2016,7 @@ func.func @vector_load(%src : memref<?xi8>) {
 
 // -----
 
-func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) {
+func.func @load_negative_alignment(%memref: memref<4xi32>, %c0: index) {
   // expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
   %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
   return
@@ -1960,7 +2024,7 @@ func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) {
 
 // -----
 
-func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) {
+func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
   // expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
   %val = vector.load %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
   return
@@ -1981,7 +2045,7 @@ func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
 
 // -----
 
-func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
+func.func @store_negative_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
   // expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
   vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
   return
@@ -1989,7 +2053,7 @@ func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>,
 
 // -----
 
-func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
+func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
   // expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
   vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
   return
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 9f882ad6f22e8..07d22120153fe 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -545,6 +545,15 @@ define void @masked_expand_compress_intrinsics(ptr %0, <7 x i1> %1, <7 x float> %2) {
   ret void
 }
 
+; CHECK-LABEL: llvm.func @masked_expand_compress_intrinsics_with_alignment
+define void @masked_expand_compress_intrinsics_with_alignment(ptr %0, <7 x i1> %1, <7 x float> %2) {
+  ; CHECK: %[[val1:.+]] = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> vector<7xf32>
+  %4 = call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %0, <7 x i1> %1, <7 x float> %2)
+  ; CHECK: "llvm.intr.masked.compressstore"(%[[val1]], %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> ()
+  call void @llvm.masked.compressstore.v7f32(<7 x float> %4, ptr align 8 %0, <7 x i1> %1)
+  ret void
+}
+
 ; CHECK-LABEL: llvm.func @annotate_intrinsics
 define void @annotate_intrinsics(ptr %var, ptr %ptr, i16 %int, ptr %annotation, ptr %fileName, i32 %line, ptr %args) {
   ; CHECK: "llvm.intr.var.annotation"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr) -> ()
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 2b420ed246fb2..c99dde36f5ccb 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -577,6 +577,17 @@ llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr, %mask: vector<7xi1
   llvm.return
 }
 
+// CHECK-LABEL: @masked_expand_compress_intrinsics_with_alignment
+llvm.func @masked_expand_compress_intrinsics_with_alignment(%ptr: !llvm.ptr, %mask: vector<7xi1>, %passthru: vector<7xf32>) {
+  // CHECK: call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+  %0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru) {arg_attrs = [{llvm.align = 8 : i32}, {}, {}]}
+    : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> (vector<7xf32>)
+  // CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, ptr align 8 %{{.*}}, <7 x i1> %{{.*}})
+  "llvm.intr.masked.compressstore"(%0, %ptr, %mask) {arg_attrs = [{}, {llvm.align = 8 : i32}, {}]}
+    : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> ()
+  llvm.return
+}
+
 // CHECK-LABEL: @annotate_intrinsics
 llvm.func @annotate_intrinsics(%var: !llvm.ptr, %int: i16, %ptr: !llvm.ptr, %annotation: !llvm.ptr, %fileName: !llvm.ptr, %line: i32, %attr: !llvm.ptr) {
   // CHECK: call void @llvm.var.annotation.p0.p0(ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, ptr %{{.*}})
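As the import test shows, `ptr align 8` on the intrinsic call round-trips to `arg_attrs = [{llvm.align = 8 : i64}, {}, {}]` on the generic op. For completeness, a sketch of reading that alignment back off an `llvm.intr.masked.expandload` — the helper is hypothetical, assuming the usual ODS-generated accessor for the optional `arg_attrs` attribute and the `getAlignAttrName` hook used by the patch:

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include <cstdint>
#include <optional>

using namespace mlir;

// Sketch only: returns the pointer alignment, or std::nullopt when the
// attribute is absent (LangRef then treats the alignment as 1).
static std::optional<uint64_t> getExpandLoadAlign(LLVM::masked_expandload op) {
  ArrayAttr argAttrs = op.getArgAttrsAttr();
  if (!argAttrs)
    return std::nullopt;
  // The pointer is expandload's first argument, so its attribute dictionary
  // sits at index 0 of arg_attrs.
  auto dict = cast<DictionaryAttr>(argAttrs[0]);
  if (auto align =
          dict.getAs<IntegerAttr>(LLVM::LLVMDialect::getAlignAttrName()))
    return align.getInt();
  return std::nullopt;
}
```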