diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt
index 2add8d062..2916a5cc6 100644
--- a/build_tools/llvm_version.txt
+++ b/build_tools/llvm_version.txt
@@ -1 +1 @@
-228e96b28a84828e1720c387a339a7e68dbdc029
+92164faf17d553359418b9f49c1a41d680d0de49
diff --git a/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch b/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch
index 1fe88c2f0..d5c96ef21 100644
--- a/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch
+++ b/build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch
@@ -1,6 +1,6 @@
-From 4167e203a75627ca13d8ea7560aaea9a6bb506f0 Mon Sep 17 00:00:00 2001
+From e0189210ee8e532bea15f0592a801f2264b62834 Mon Sep 17 00:00:00 2001
 From: Garra1980
-Date: Sat, 12 Jul 2025 00:39:57 +0200
+Date: Wed, 13 Aug 2025 17:08:31 +0200
 Subject: [PATCH] Add support for VectorAnyINTEL capability

 ---
@@ -12,22 +12,22 @@ Subject: [PATCH] Add support for VectorAnyINTEL capability
 .../arith-to-spirv-unsupported.mlir | 4 +-
 .../ArithToSPIRV/arith-to-spirv.mlir | 34 +++++
 .../FuncToSPIRV/types-to-spirv.mlir | 17 ++-
- .../test/Dialect/SPIRV/IR/arithmetic-ops.mlir | 2 +-
+ .../test/Dialect/SPIRV/IR/arithmetic-ops.mlir | 1 +
 mlir/test/Dialect/SPIRV/IR/bit-ops.mlir | 6 +-
 mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 2 +-
- mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 4 +-
+ mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 3 +-
 mlir/test/Dialect/SPIRV/IR/logical-ops.mlir | 2 +-
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir | 12 +-
 mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 34 ++---
 mlir/test/Target/SPIRV/arithmetic-ops.mlir | 6 +-
 mlir/test/Target/SPIRV/ocl-ops.mlir | 6 +
- 17 files changed, 322 insertions(+), 69 deletions(-)
+ 17 files changed, 322 insertions(+), 67 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
-index 910418f1706a..29af93d8e752 100644
+index bdfd728d1d0b..31e8bc288d5b 100644
 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
 +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
-@@ -4217,7 +4217,14 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
+@@ -4233,7 +4233,14 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
 def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
 [SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
-@@ -4270,7 +4277,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
+@@ -4286,7 +4293,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
 "Matrix">;

 class SPIRV_VectorOf<Type type> :
-- VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
-+ VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>;
+- FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
++ FixedVectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>;

 class SPIRV_ScalarOrVectorOf<Type type> :
 AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
-index 45ec1846580f..6ca59f91eee9 100644
+index b682f4c025a4..298553c83947 100644
 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td
 +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -648,6 +648,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
 // Negative values for `n` index in reverse.
 class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
-index 88c7adf3dfcb..d29c88a1fd53 100644
+index fcf152649197..bbc538a19840 100644
 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
 +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
-@@ -188,9 +188,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
+@@ -186,9 +186,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
 return Type();
 }
 return Type();
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
-index 2b90df42af5c..34f25f2b3bc9 100644
+index ddb342621f37..952c474fd34d 100644
 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
 +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
-@@ -101,9 +101,10 @@ bool CompositeType::classof(Type type) {
+@@ -98,9 +98,10 @@ bool CompositeType::classof(Type type) {
 }

 bool CompositeType::isValid(VectorType type) {
 }

 Type CompositeType::getElementType(unsigned index) const {
-@@ -174,7 +175,21 @@ void CompositeType::getCapabilities(
+@@ -171,7 +172,21 @@ void CompositeType::getCapabilities(
 .Case<VectorType>([&](VectorType type) {
 auto vecSize = getNumElements();
 if (vecSize == 8 || vecSize == 16) {
 capabilities.push_back(ref);
 }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
-index 1e7bb046d375..24e633da72aa 100644
+index 49f4ce8de7c7..eef55a427486 100644
 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
 +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
-@@ -87,9 +87,13 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+@@ -84,9 +84,13 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
 template <typename LabelT>
 static LogicalResult checkExtensionRequirements(
 LabelT label, const spirv::TargetEnv &targetEnv,
 continue;

 LLVM_DEBUG({
-@@ -115,9 +119,13 @@ static LogicalResult checkCapabilityRequirements(
+@@ -112,9 +116,13 @@ static LogicalResult checkCapabilityRequirements(
 template <typename LabelT>
 static LogicalResult checkCapabilityRequirements(
 LabelT label, const spirv::TargetEnv &targetEnv,
 continue;

 LLVM_DEBUG({
-@@ -134,6 +142,55 @@ checkCapabilityRequirements(
+@@ -131,6 +139,55 @@ checkCapabilityRequirements(
 return success();
 }

 /// Returns true if the given `storageClass` needs explicit layout when used in
 /// Shader environments.
 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
-@@ -279,11 +336,14 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
+@@ -284,11 +341,14 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
 return nullptr;
 }

 auto intType = cast<IntegerType>(type);
 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
 return IntegerType::get(targetEnv.getContext(), /*width=*/32,
-@@ -358,10 +418,13 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
+@@ -402,10 +462,13 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
 if (type.getRank() <= 1 && type.getNumElements() == 1)
 return elementType;

 return nullptr;
 }

-@@ -383,16 +446,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
+@@ -427,16 +490,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);

 }

 static Type
-@@ -1563,16 +1650,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
+@@ -1693,16 +1780,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
 for (Type valueType : valueTypes) {
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096e..3aa22e261f7c 100644
 }
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
-index 1abe0fd2ec46..f64436fa2632 100644
+index 6e2352e706ac..4c9d2e147bc6 100644
 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
 +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
-@@ -1462,6 +1462,40 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
+@@ -1479,6 +1479,40 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
 %2 = arith.muli %arg0, %arg1 overflow : i64
 // CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
 %3 = arith.muli %arg0, %arg1 overflow : i64
 }
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
-index 1737f4a906bf..13f4e17167ef 100644
+index 0c77c8833457..d6628afb7329 100644
 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
 +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
-@@ -345,8 +345,21 @@ module attributes {
+@@ -347,8 +347,21 @@ module attributes {
 spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>
} {
 } // end module
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
-index 3adafc15c79f..f75fd6cb0d39 100644
+index c703274bda57..670edc9deb91 100644
 --- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
 +++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
-@@ -348,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
- // -----

 func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
-- // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit 
float or BFloat16 values of length 2/3/4/8/16}} -+ // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2-4294967295, but got 'vector<4xi32>'}} + // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}} ++ // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2-4294967295}} %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32 return %0 : i32 } diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir -index f3f0ebf60f46..1138f38bcef2 100644 +index 4bdac198a1e8..dee8c7f9a65e 100644 --- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { // ----- func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 { -- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} -+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}} +- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} ++ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295}} %0 = spirv.BitwiseOr %arg0, %arg1 : f16 return %0 : f16 } @@ -528,8 +527,8 @@ index f3f0ebf60f46..1138f38bcef2 100644 // ----- func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 { -- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} -+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-9223372036854775807}} +- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} ++ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-9223372036854775807}} %0 = spirv.BitwiseXor %arg0, %arg1 : f16 return %0 : f16 } @@ -537,69 +536,68 @@ index f3f0ebf60f46..1138f38bcef2 100644 // ----- func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 { -- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} -+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-9223372036854775807}} +- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} ++ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-9223372036854775807}} %0 = spirv.BitwiseAnd %arg0, %arg1 : f16 return %0 : f16 } diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir -index 5c5d94c40e57..8edaa3762c23 100644 +index fd8a2ffbbddf..011759101a74 100644 --- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- func.func @exp(%arg0 : vector<5xf32>) -> () { -- // expected-error @+1 {{op 
operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values of length 2/3/4}}
+ // CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32>
 %2 = spirv.GL.Exp %arg0 : vector<5xf32>
 return
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
-index bb15d018a6c4..f23c2b329a51 100644
+index 2e2fb1a9df32..ad8a66e16745 100644
 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
 +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
 @@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
 // -----

 spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
- // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
+ // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2-4294967295, but got 'f64'}}
 %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
 spirv.Return
 }
-@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
+@@ -58,6 +58,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
 spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
 // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
++ // expected-error @+1 {{op result #0 must be Float32 or fixed-length vector of Float32 values of length 2-4294967295, but got 'f16'}}
 %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
 spirv.Return
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
-index d6c34645f574..c24892a00d5a 100644
+index d7f4ed05969a..3acd5b88e42a 100644
 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
 +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
-@@ -166,7 +166,7 @@ func.func @logicalUnary(%arg0 : i1)
+@@ -184,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1)
 func.func @logicalUnary(%arg0 : i32)
 {
- // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}}
 %0 = spirv.LogicalNot %arg0 : i32
 return
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
-index 7ab94f17360d..07d85ca5fa90 100644
+index bdb2abde8d8e..7b9b5d9a4688 100644
 --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
 +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
 @@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
 // -----

 func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 {
-- // expected-error 
@+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} -+ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} +- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} ++ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} %0 = spirv.GroupNonUniformBitwiseAnd %val : i1 -> i1 return %0: i1 } @@ -607,8 +605,8 @@ index 7ab94f17360d..07d85ca5fa90 100644 // ----- func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 { -- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} -+ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} +- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} ++ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} %0 = spirv.GroupNonUniformBitwiseOr %val : i1 -> i1 return %0: i1 } @@ -616,8 +614,8 @@ index 7ab94f17360d..07d85ca5fa90 100644 // ----- func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 { -- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} -+ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} +- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} ++ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2-4294967295, but got 'i1'}} %0 = spirv.GroupNonUniformBitwiseXor %val : i1 -> i1 return %0: i1 } @@ -625,8 +623,8 @@ index 7ab94f17360d..07d85ca5fa90 100644 // ----- func.func @group_non_uniform_logical_and(%val: i32) -> i32 { -- // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} -+ // expected-error @+1 {{op operand #0 must be bool or vector of bool values of length 2-4294967295, but got 'i32'}} +- // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} ++ // expected-error @+1 {{op operand #0 must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}} %0 = spirv.GroupNonUniformLogicalAnd %val : i32 -> i32 return %0: i32 } @@ -634,8 +632,8 @@ index 7ab94f17360d..07d85ca5fa90 100644 // ----- func.func @group_non_uniform_logical_or(%val: i32) -> i32 { -- // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} -+ // expected-error @+1 {{op operand #0 must be bool or vector of bool values of length 2-4294967295, but got 'i32'}} +- // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} ++ // expected-error @+1 {{op 
operand #0 must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}} %0 = spirv.GroupNonUniformLogicalOr %val : i32 -> i32 return %0: i32 } @@ -643,20 +641,20 @@ index 7ab94f17360d..07d85ca5fa90 100644 // ----- func.func @group_non_uniform_logical_xor(%val: i32) -> i32 { -- // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} -+ // expected-error @+1 {{op operand #0 must be bool or vector of bool values of length 2-4294967295, but got 'i32'}} +- // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} ++ // expected-error @+1 {{op operand #0 must be bool or fixed-length vector of bool values of length 2-4294967295, but got 'i32'}} %0 = spirv.GroupNonUniformLogicalXor %val : i32 -> i32 return %0: i32 } diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir -index 8f021ed3d663..21558b9607f8 100644 +index 6aaaa6012fef..60ef7afeeeed 100644 --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- func.func @exp(%arg0 : vector<5xf32>) -> () { -- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} +- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}} + // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32> %2 = spirv.CL.exp %arg0 : vector<5xf32> return @@ -681,7 +679,7 @@ index 8f021ed3d663..21558b9607f8 100644 // ----- -func.func @fabs(%arg0 : vector<5xf32>) -> () { -- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} +- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}} - %2 = spirv.CL.fabs %arg0 : vector<5xf32> - return -} @@ -711,7 +709,7 @@ index 8f021ed3d663..21558b9607f8 100644 // ----- -func.func @sabs(%arg0 : vector<5xi32>) -> () { -- // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} +- // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} - %2 = spirv.CL.s_abs %arg0 : vector<5xi32> - return -} diff --git a/build_tools/patches/0004-Add-serialization-and-de-serialization-support-for-s.patch b/build_tools/patches/0004-Add-serialization-and-de-serialization-support-for-s.patch index 2b87875d2..44150a759 100644 --- a/build_tools/patches/0004-Add-serialization-and-de-serialization-support-for-s.patch +++ b/build_tools/patches/0004-Add-serialization-and-de-serialization-support-for-s.patch @@ -1,7 +1,7 @@ -From 89e527e48b727a1479aa47fdbe3d2d178d8969a7 Mon Sep 17 00:00:00 2001 +From 5900db1c91d40157c2724d324ea65e22936e3354 Mon Sep 17 00:00:00 2001 From: Garra1980 -Date: Mon, 4 Aug 2025 17:50:56 +0200 -Subject: [PATCH] Add serilialization and deserialization for spirv +Date: Tue, 12 Aug 2025 23:41:51 +0200 +Subject: [PATCH] Add serialization and de-serialization support for spirv --- mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 6 ++++++ @@ -9,10 +9,10 @@ Subject: [PATCH] Add serilialization and deserialization for spirv 2 files changed, 12 insertions(+) diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp 
b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
-index 88931b53a688..f1c22d09cc8e 100644
+index d8c54ec5f88c..3b539382dedd 100644
 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
 +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
-@@ -282,6 +282,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
+@@ -283,6 +283,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
 symbol, FPRoundingModeAttr::get(opBuilder.getContext(), static_cast<FPRoundingMode>(words[2])));
 break;
+
 case spirv::Decoration::DescriptorSet:
 case spirv::Decoration::Binding:
 if (words.size() != 3) {
-@@ -343,6 +344,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
 case spirv::Decoration::RestrictPointer:
 case spirv::Decoration::NoContraction:
+@@ -346,6 +347,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
 case spirv::Decoration::Constant:
 case spirv::Decoration::Invariant:
 case spirv::Decoration::Patch:
+ case spirv::Decoration::SingleElementVectorINTEL:
+ case spirv::Decoration::VectorComputeCallableFunctionINTEL:
+ case spirv::Decoration::VectorComputeFunctionINTEL:
+ case spirv::Decoration::VectorComputeVariableINTEL:
 if (words.size() != 2) {
 return emitError(unknownLoc, "OpDecoration with ")
 << decorationName << "needs a single target ";
-@@ -351,6 +356,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
 break;
 case spirv::Decoration::Location:
 case spirv::Decoration::SpecId:
+@@ -354,6 +359,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
 return emitError(unknownLoc, "OpDecoration with ")
 << decorationName << "needs a single integer literal";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
-index 737f29662f64..cd925b02b6a6 100644
+index 7c007de31558..3aa26ab923a9 100644
 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
 +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
-@@ -283,8 +283,10 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
+@@ -302,8 +302,10 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
 }
 return emitError(loc, "expected FPRoundingModeAttr attribute for ")
 << stringifyDecoration(decoration);
+
 case spirv::Decoration::Location:
 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
 args.push_back(intAttr.getValue().getZExtValue());
-@@ -318,6 +320,10 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
 case spirv::Decoration::RestrictPointer:
 case spirv::Decoration::NoContraction:
 case spirv::Decoration::Constant:
+@@ -340,6 +342,10 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
 case spirv::Decoration::Block:
 case spirv::Decoration::Invariant:
 case spirv::Decoration::Patch:
+ case spirv::Decoration::SingleElementVectorINTEL:
+ case spirv::Decoration::VectorComputeCallableFunctionINTEL:
+ case spirv::Decoration::VectorComputeFunctionINTEL:
+ case spirv::Decoration::VectorComputeVariableINTEL:
 // For unit attributes and decoration attributes, the args list
 // has no values so we do nothing.
---
+ if (isa(attr))
--
2.34.1
-
diff --git a/build_tools/patches/0008-xegpu-temporary-downstream-defintion-changes-and-vec.patch b/build_tools/patches/0008-xegpu-temporary-downstream-defintion-changes-and-vec.patch
index 530993a30..d27bb5229 100644
--- a/build_tools/patches/0008-xegpu-temporary-downstream-defintion-changes-and-vec.patch
+++ b/build_tools/patches/0008-xegpu-temporary-downstream-defintion-changes-and-vec.patch
@@ -14,7 +14,7 @@
 index 7f4d4f1381df..ebd4f1a3f66a 100644
 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
 +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
 @@ -373,6 +373,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
 OptionalAttr<UnitAttr>: $packed,
 OptionalAttr<DenseI64ArrayAttr>: $transpose,
+ OptionalAttr<I32Attr>: $transpose_bit_width,
 @@ -1147,4 +1148,9 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
 let hasCanonicalizer = 1;
 }

+def XeGPU_CompileHintOp : XeGPU_Op<"compile_hint", []> {
+ let summary = "prevents the compiler from scheduling.";
+ let assemblyFormat = [{ attr-dict }];
+}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3fa229..528b9d55ee61 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+ kind == CachePolicy::STREAMING || kind == CachePolicy::WRITE_BACK ||
 kind == CachePolicy::WRITE_THROUGH;
 }

@@ -419,8 +420,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
 xegpu::CachePolicyAttr l3_hint) {

 return build(builder, state, retType, tensorDesc, ValueRange(),
- DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
- l3_hint);
+ DenseI64ArrayAttr(), packed, transpose, nullptr,
+ l1_hint, l2_hint, l3_hint);
 }

 LogicalResult LoadNdOp::verify() {
@@ -482,7 +483,7 @@ LogicalResult LoadNdOp::verify() {
 mlir::emitWarning(getLoc()) << "Invalid transpose attr. 
It is ignored.";
 }

- if (getPacked()) {
+ if (getPacked() || getTransposeBitWidth() == 32) {
 if (tdescTy.getRank() == 2) {
 const int axis = 0;
 auto vnni_factor = valueShape.back();
--
2.34.1
-
diff --git a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp
index 421827de9..6128376f4 100644
--- a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp
+++ b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp
@@ -1198,9 +1198,9 @@ class PrefetchPattern : public OpConversionPattern {
 // auto l2hint = op.getL2Hint();
 auto l3hint = op.getL3Hint();
- auto callOp = genPrefetchIntrinsicCall(rewriter, loc, simd_lanes, l1hint,
- l3hint, elemTy, chunkSize, scope,
- adaptor.getSource());
+ auto callOp =
+ genPrefetchIntrinsicCall(rewriter, loc, simd_lanes, l1hint, l3hint,
+ elemTy, chunkSize, scope, adaptor.getSource());
 rewriter.replaceOp(op, callOp);
 return success();
diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp
index 41c10daeb..dbbd25f7d 100644
--- a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp
+++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp
@@ -491,10 +491,9 @@ class LoadOpPattern : public OpConversionPattern {
 auto packAttr = UnitAttr();
 auto transAttr = DenseI64ArrayAttr();
 auto bitWidthAttr = IntegerAttr();
- auto ldOp = rewriter.create<xegpu::LoadNdOp>(loc, vecTy, adaptor.getTile(),
- ValueRange(), DenseI64ArrayAttr(),
- packAttr, transAttr,
- bitWidthAttr, L1, L2, L3);
+ auto ldOp = rewriter.create<xegpu::LoadNdOp>(
+ loc, vecTy, adaptor.getTile(), ValueRange(), DenseI64ArrayAttr(),
+ packAttr, transAttr, bitWidthAttr, L1, L2, L3);
 llvm::SmallVector<mlir::Value> results({ldOp.getResult()});

 if (memSpace == xegpu::MemorySpace::SLM) {
diff --git a/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp b/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp
index 0660ecf82..25bb9e0d2 100644
--- a/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp
+++ b/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp
@@ -100,7 +100,7 @@ static T getBaseShardDimOff(T shard, T numShards, T extend) {
 }

 static Sharding ShardingFromOption(const ShardingOption &option,
- MLIRContext *ctxt) {
+ MLIRContext *ctxt) {
 SmallVector<GridAxesAttr> res;
 for (const auto &v : option.shardingArray) {
 res.emplace_back(GridAxesAttr::get(ctxt, v));
@@ -141,7 +141,8 @@ getShardingWithShardedDimsOffs(Value ary, OffsetSizeAndStrideOpInterface op) {
 ShapedType::isDynamicShape(strides))
 return op->emitOpError("Dynamic offsets/sizes/strides are not supported");

- auto arySharding = aryShardOp.getSharding().getDefiningOp();
+ auto arySharding =
+ aryShardOp.getSharding().getDefiningOp();
 // currently no support for sharding dims sizes on input
 if (!arySharding.getStaticShardedDimsOffsets().empty())
 return op->emitOpError(
@@ -190,10 +191,9 @@ getShardingWithShardedDimsOffs(Value ary, OffsetSizeAndStrideOpInterface op) {
 }
 }

- return Sharding::get(
- arySharding.getGridAttr(), arySharding.getSplitAxes().getAxes(),
- {}, // static halo
- splitOffs, {}, {});
+ return Sharding::get(arySharding.getGridAttr(),
+ arySharding.getSplitAxes().getAxes(), {}, // static halo
+ splitOffs, {}, {});
 }

 static std::pair
diff --git a/lib/Dialect/XeTile/Transforms/Blocking.cpp b/lib/Dialect/XeTile/Transforms/Blocking.cpp
index 357d52368..3ce1aaed2 100644
--- a/lib/Dialect/XeTile/Transforms/Blocking.cpp
+++ b/lib/Dialect/XeTile/Transforms/Blocking.cpp
@@ -1042,8 +1042,8 @@ class RewriteTileReductionOp
 for (auto v : intermediates) {
 auto resultTy = VectorType::get({1, 1}, elemTy);
 for (auto i = 0; i < blkSize[1]; 
i++) {
- auto extractOp =
- rewriter.create<vector::ExtractOp>(loc, v, rewriter.getIndexAttr(i));
+ auto extractOp = rewriter.create<vector::ExtractOp>(
+ loc, v, rewriter.getIndexAttr(i));
 auto splatOp = rewriter.create<vector::SplatOp>(op.getLoc(), resultTy,
 extractOp);
 newOps.push_back(splatOp);
diff --git a/lib/Target/CMakeLists.txt b/lib/Target/CMakeLists.txt
index 528b6daa8..8bd310735 100644
--- a/lib/Target/CMakeLists.txt
+++ b/lib/Target/CMakeLists.txt
@@ -1 +1 @@
-add_subdirectory(LLVM)
\ No newline at end of file
+add_subdirectory(LLVM)
diff --git a/lib/Transforms/OptimizeTranspose.cpp b/lib/Transforms/OptimizeTranspose.cpp
index e8df6c8ec..b43b814f8 100644
--- a/lib/Transforms/OptimizeTranspose.cpp
+++ b/lib/Transforms/OptimizeTranspose.cpp
@@ -516,10 +516,10 @@ struct LoadNdOpPattern : public OpConversionPattern<xegpu::LoadNdOp> {
 op.getType().getElementType());
 for (auto source : tdescSources) {
 auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
- op.getLoc(), newLoadTy, source,
- ValueRange(), DenseI64ArrayAttr(), op.getPackedAttr(),
- op.getTransposeAttr(), op.getTransposeBitWidthAttr(),
- op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ op.getLoc(), newLoadTy, source, ValueRange(), DenseI64ArrayAttr(),
+ op.getPackedAttr(), op.getTransposeAttr(),
+ op.getTransposeBitWidthAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
 loadNdOps.push_back(loadNdOp);
 }
 rewriter.replaceOpWithMultiple(op, {loadNdOps});
@@ -847,10 +847,10 @@ struct TransposeRewritePattern : public OpRewritePattern<vector::TransposeOp> {
 rewriter.getIntegerType(32), 32);
 // need to do a 32 bit transpose to get the packed layout.
 auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
- loadOp.getLoc(), newVectorTy, loadOp.getTensorDesc(),
- ValueRange(), DenseI64ArrayAttr(), packedAttr,
- transposeAttr, transposeBitWidthAttr, loadOp.getL1HintAttr(),
- loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+ loadOp.getLoc(), newVectorTy, loadOp.getTensorDesc(), ValueRange(),
+ DenseI64ArrayAttr(), packedAttr, transposeAttr, transposeBitWidthAttr,
+ loadOp.getL1HintAttr(), loadOp.getL2HintAttr(),
+ loadOp.getL3HintAttr());
 // Replace the uses of the packed layout conversion with new load.
rewriter.replaceAllUsesWith(packedLayoutOps.back()->getResult(0),
 newLoadOp.getResult());
@@ -872,10 +872,10 @@ struct TransposeRewritePattern : public OpRewritePattern<vector::TransposeOp> {
 auto transposeAttr =
 DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0});
 auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
- loadOp.getLoc(), newVectorTy, loadOp.getTensorDesc(),
- ValueRange(), DenseI64ArrayAttr(), packedAttr,
- transposeAttr, IntegerAttr(), loadOp.getL1HintAttr(),
- loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+ loadOp.getLoc(), newVectorTy, loadOp.getTensorDesc(), ValueRange(),
+ DenseI64ArrayAttr(), packedAttr, transposeAttr, IntegerAttr(),
+ loadOp.getL1HintAttr(), loadOp.getL2HintAttr(),
+ loadOp.getL3HintAttr());
 rewriter.replaceAllUsesWith(op.getResult(), newLoadOp.getResult());
 }
diff --git a/lib/Transforms/RemoveSingleElemVector.cpp b/lib/Transforms/RemoveSingleElemVector.cpp
index 0d77454b8..2fcd3d5ae 100644
--- a/lib/Transforms/RemoveSingleElemVector.cpp
+++ b/lib/Transforms/RemoveSingleElemVector.cpp
@@ -33,8 +33,7 @@ namespace {

 struct VectorExtractOpConversion final
 : public mlir::OpConversionPattern<mlir::vector::ExtractOp> {
- using mlir::OpConversionPattern<
- mlir::vector::ExtractOp>::OpConversionPattern;
+ using mlir::OpConversionPattern<mlir::vector::ExtractOp>::OpConversionPattern;

 mlir::LogicalResult
 matchAndRewrite(mlir::vector::ExtractOp extractOp, OpAdaptor adaptor,
@@ -84,8 +83,8 @@ struct VectorExtractStridedSliceConversion final
 // We only convert ops extracting a single element from a 1D vector.
 if (resType.getNumElements() == 1 && srcVector.getType().getRank() == 1) {
- rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(
- extractOp, srcVector, offsets[0]);
+ rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(extractOp, srcVector,
+ offsets[0]);
 return mlir::success();
 }
 return mlir::failure();
@@ -122,9 +121,8 @@ struct VectorizableOpPattern final
 };

 template <typename OpTy>
-static mlir::Value
-createInsertOps(OpTy op, mlir::ValueRange operands,
- mlir::ConversionPatternRewriter &rewriter) {
+static mlir::Value createInsertOps(OpTy op, mlir::ValueRange operands,
+ mlir::ConversionPatternRewriter &rewriter) {
 auto loc = op.getLoc();
 auto type = op.getType();
 auto elemType = type.getElementType();
@@ -139,8 +137,7 @@ createInsertOps(OpTy op, mlir::ValueRange operands,
 mlir::Value newOp = rewriter.create<mlir::arith::ConstantOp>(loc, type, denseAttr);

 for (auto [i, opr] : llvm::enumerate(operands)) {
- newOp =
- rewriter.create<mlir::vector::InsertOp>(loc, opr, newOp, i);
+ newOp = rewriter.create<mlir::vector::InsertOp>(loc, opr, newOp, i);
 }
 return newOp;
 }
@@ -267,7 +264,8 @@ struct RemoveSingleElemVectorPass final
 return mlir::Value();

 return builder
- .create<mlir::vector::ExtractOp>(loc, inputs[0], builder.getIndexAttr(0))
+ .create<mlir::vector::ExtractOp>(loc, inputs[0],
+ builder.getIndexAttr(0))
 .getResult();
 };
diff --git a/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir b/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir
index 21dfd7972..a7d9615f4 100644
--- a/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir
+++ b/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir
@@ -64,19 +64,18 @@ gpu.module @test {
 //CHECK: %[[cast_1:.*]] = memref.cast %[[arg2]] : memref<*xf32> to memref<?xf32>
 //CHECK: %[[block_id_x:.*]] = gpu.block_id x
 //CHECK: %[[r0:.*]] = arith.muli %[[block_id_x]], %[[c1024]] : index
- //CHECK: %[[r1:.*]] = vector.splat %[[r0]] : vector<1x16xindex>
- //CHECK: %[[r2:.*]] = vector.shape_cast %[[r1]] : vector<1x16xindex> to vector<16xindex>
- //CHECK: %[[r3:.*]] = xegpu.create_tdesc %[[cast]], %[[r2]] : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- //CHECK: %[[r4:.*]] = xegpu.load %[[r3]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = 
#xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
- //CHECK: %[[r5:.*]] = vector.shape_cast %[[r4]] : vector<16xf32> to vector<1x16xf32>
- //CHECK: %[[r6:.*]] = xegpu.create_tdesc %[[cast_0]], %[[r2]] : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- //CHECK: %[[r7:.*]] = xegpu.load %[[r6]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
- //CHECK: %[[r8:.*]] = vector.shape_cast %[[r7]] : vector<16xf32> to vector<1x16xf32>
- //CHECK: %[[r9:.*]] = arith.addf %[[r5]], %[[r8]] : vector<1x16xf32>
- //CHECK: %[[r10:.*]] = xegpu.create_tdesc %[[cast_1]], %[[r2]] : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- //CHECK: %[[r11:.*]] = vector.shape_cast %[[r9]] : vector<1x16xf32> to vector<16xf32>
- //CHECK: xegpu.store %[[r11]], %[[r10]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
- //CHECK: xegpu.store %[[r11]], %[[r10]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+ //CHECK: %[[r1:.*]] = vector.broadcast %[[r0]] : index to vector<16xindex>
+ //CHECK: %[[r2:.*]] = xegpu.create_tdesc %[[cast]], %[[r1]] : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ //CHECK: %[[r3:.*]] = xegpu.load %[[r2]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ //CHECK: %[[r4:.*]] = vector.shape_cast %[[r3]] : vector<16xf32> to vector<1x16xf32>
+ //CHECK: %[[r5:.*]] = xegpu.create_tdesc %[[cast_0]], %[[r1]] : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ //CHECK: %[[r6:.*]] = xegpu.load %[[r5]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ //CHECK: %[[r7:.*]] = vector.shape_cast %[[r6]] : vector<16xf32> to vector<1x16xf32>
+ //CHECK: %[[r8:.*]] = arith.addf %[[r4]], %[[r7]] : vector<1x16xf32>
+ //CHECK: %[[r9:.*]] = xegpu.create_tdesc %[[cast_1]], %[[r1]] : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ //CHECK: %[[r10:.*]] = vector.shape_cast %[[r8]] : vector<1x16xf32> to vector<16xf32>
+ //CHECK: xegpu.store %[[r10]], %[[r9]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+ //CHECK: xegpu.store %[[r10]], %[[r9]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
 %c1024 = arith.constant 1024 : index
 %cst = arith.constant dense<true> : vector<1x32xi1>
 %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
diff --git a/test/Dialect/NDArray/Extensions/lit.local.cfg b/test/Dialect/NDArray/Extensions/lit.local.cfg
index 20a13743d..b6d5811eb 100644
--- 
a/test/Dialect/NDArray/Extensions/lit.local.cfg
+++ b/test/Dialect/NDArray/Extensions/lit.local.cfg
@@ -2,4 +2,4 @@
 local_excludes = ['mesh-spmdization.mlir']
 if(not config.imex_enable_excluded_tests):
- config.excludes.update(local_excludes)
\ No newline at end of file
+ config.excludes.update(local_excludes)
diff --git a/test/Dialect/XeGPU/IR/invalid_vc.mlir b/test/Dialect/XeGPU/IR/invalid_vc.mlir
index 7a15bbce2..667d04ea8 100644
--- a/test/Dialect/XeGPU/IR/invalid_vc.mlir
+++ b/test/Dialect/XeGPU/IR/invalid_vc.mlir
@@ -5,31 +5,17 @@
 func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) {
 %c0 = arith.constant 2 : index
 %c1 = arith.constant 4 : index
- // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank (2 vs 1) so the rank of the result type is well-formed}}
+ // expected-error@+1 {{Expecting the rank of shape, strides, offsets, and source (if source is a memref) should match with each other}}
 %1 = xegpu.create_nd_tdesc %src[%c0, %c1] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
 return
 }

-// -----
-func.func @test_create_nd_tdesc_vc_3(%input: memref<?x?xf32>) {
- %c0 = arith.constant 2 : index
- %c1 = arith.constant 4 : index
-
- %c8 = arith.constant 8 : index
- %c16 = arith.constant 16 : index
-
- // expected-error@+1 {{expected 1 offset values, got 2}}
- %1 = xegpu.create_nd_tdesc %input[%c0, %c1], shape: [%c8, %c16], strides: [%c16, %c1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
- return
-}
-
-
// -----
func.func @test_create_nd_tdesc_vc_4(%input: memref<?xf32>) {
 %c1 = arith.constant 2 : index
 %c8 = arith.constant 8 : index

- // expected-error@+1 {{expected 2 offset values, got 1}}
+ // expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source.}}
 %1 = xegpu.create_nd_tdesc %input[%c1], shape: [%c8], strides: [%c1] : memref<?xf32> -> !xegpu.tensor_desc<8x16xf32>
 return
diff --git a/test/Transforms/RemoveSingleElemVector/unit_tests.mlir b/test/Transforms/RemoveSingleElemVector/unit_tests.mlir
index d0752cf8f..7142472f5 100644
--- a/test/Transforms/RemoveSingleElemVector/unit_tests.mlir
+++ b/test/Transforms/RemoveSingleElemVector/unit_tests.mlir
@@ -60,7 +60,7 @@ module {
 %42 = math.exp %26 : vector<1xf16>
 %43 = math.exp %27 : vector<1xf16>
 %c0_i32 = arith.constant 0 : i32
- // CHECK-COUNT-16: %{{.*}} = vector.splat %{{.*}} : vector<16xf16>
+ // CHECK-COUNT-16: %{{.*}} = vector.broadcast %{{.*}} : f16 to vector<16xf16>
 %44 = vector.extract %28[0] : f16 from vector<1xf16>
 %45 = vector.splat %44 : vector<16xf16>
 %46 = vector.extract %29[0] : f16 from vector<1xf16>
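The sketch below is appended for illustration only and is not part of the diff: it shows the kind of long-vector IR that the VectorAnyINTEL patch above is meant to let through the SPIR-V type converter. The exact #spirv.vce triple (version, capability, and extension names) is an assumption here, not taken from the patch hunks.

// Minimal sketch, assuming a kernel-style target env with VectorAnyINTEL.
// Length 5 falls outside SPIR-V's fixed 2/3/4/8/16 set; with the relaxed
// 2-4294967295 range from FixedVectorOfLengthRangeAndType it verifies,
// matching the updated expected-error strings and the new
// spirv.GL.Exp / spirv.CL.exp CHECK lines in the test changes above.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>,
    #spirv.resource_limits<>>
} {
  func.func @exp_long_vector(%arg0: vector<5xf32>) -> vector<5xf32> {
    %0 = math.exp %arg0 : vector<5xf32>
    return %0 : vector<5xf32>
  }
}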