1
- From 46f1b01b6cde4956f4b08985adc59dd530788d4a Mon Sep 17 00:00:00 2001
1
+ From 4167e203a75627ca13d8ea7560aaea9a6bb506f0 Mon Sep 17 00:00:00 2001
2
2
From: Garra1980 <
[email protected] >
3
- Date: Fri, 13 Jun 2025 19:12:11 +0200
3
+ Date: Sat, 12 Jul 2025 00:39:57 +0200
4
4
Subject: [PATCH] Add support for VectorAnyINTEL capability
5
5
6
6
---
@@ -24,10 +24,10 @@ Subject: [PATCH] Add support for VectorAnyINTEL capability
24
24
17 files changed, 322 insertions(+), 69 deletions(-)
25
25
26
26
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
27
- index d2ba76cdad90..ac491f6068a0 100644
27
+ index 910418f1706a..29af93d8e752 100644
28
28
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
29
29
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
30
- @@ -4194 ,7 +4194 ,14 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
30
+ @@ -4217 ,7 +4217 ,14 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
31
31
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
32
32
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
33
33
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
@@ -43,7 +43,7 @@ index d2ba76cdad90..ac491f6068a0 100644
43
43
[SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
44
44
// Component type check is done in the type parser for the following SPIR-V
45
45
// dialect-specific types so we use "Any" here.
46
- @@ -4245 ,7 +4252 ,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
46
+ @@ -4270 ,7 +4277 ,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
47
47
"Matrix">;
48
48
49
49
class SPIRV_VectorOf<Type type> :
@@ -150,7 +150,7 @@ index 45ec1846580f..6ca59f91eee9 100644
150
150
// Negative values for `n` index in reverse.
151
151
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
152
152
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
153
- index a21acef1c4b4..a7c60f6bf1cb 100644
153
+ index 88c7adf3dfcb..d29c88a1fd53 100644
154
154
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
155
155
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
156
156
@@ -188,9 +188,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
@@ -169,10 +169,10 @@ index a21acef1c4b4..a7c60f6bf1cb 100644
169
169
return Type();
170
170
}
171
171
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
172
- index 93e0c9b33c54..a349da00027e 100644
172
+ index 2b90df42af5c..34f25f2b3bc9 100644
173
173
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
174
174
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
175
- @@ -100 ,9 +100 ,10 @@ bool CompositeType::classof(Type type) {
175
+ @@ -101 ,9 +101 ,10 @@ bool CompositeType::classof(Type type) {
176
176
}
177
177
178
178
bool CompositeType::isValid(VectorType type) {
@@ -186,7 +186,7 @@ index 93e0c9b33c54..a349da00027e 100644
186
186
}
187
187
188
188
Type CompositeType::getElementType(unsigned index) const {
189
- @@ -164 ,7 +165 ,21 @@ void CompositeType::getCapabilities(
189
+ @@ -174 ,7 +175 ,21 @@ void CompositeType::getCapabilities(
190
190
.Case<VectorType>([&](VectorType type) {
191
191
auto vecSize = getNumElements();
192
192
if (vecSize == 8 || vecSize == 16) {
@@ -210,7 +210,7 @@ index 93e0c9b33c54..a349da00027e 100644
210
210
capabilities.push_back(ref);
211
211
}
212
212
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
213
- index f5a58c58e05d..406e81235e10 100644
213
+ index 1e7bb046d375..24e633da72aa 100644
214
214
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
215
215
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
216
216
@@ -87,9 +87,13 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
@@ -499,15 +499,15 @@ index 1737f4a906bf..13f4e17167ef 100644
499
499
} // end module
500
500
501
501
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
502
- index d58c27598f2b..4c22244c08e0 100644
502
+ index 3adafc15c79f..f75fd6cb0d39 100644
503
503
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
504
504
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
505
- @@ -339 ,7 +339 ,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
505
+ @@ -348 ,7 +348 ,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
506
506
// -----
507
507
508
508
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
509
- - // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
510
- + // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2-4294967295, but got 'vector<4xi32>'}}
509
+ - // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
510
+ + // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2-4294967295, but got 'vector<4xi32>'}}
511
511
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
512
512
return %0 : i32
513
513
}
@@ -543,7 +543,7 @@ index f3f0ebf60f46..1138f38bcef2 100644
543
543
return %0 : f16
544
544
}
545
545
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
546
- index 642346cc40b0..10ede222ada7 100644
546
+ index 5c5d94c40e57..8edaa3762c23 100644
547
547
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
548
548
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
549
549
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
@@ -722,7 +722,7 @@ index 8f021ed3d663..21558b9607f8 100644
722
722
// expected-error @+1 {{expected ':'}}
723
723
%2 = spirv.CL.s_abs %arg0, %arg1 : i32
724
724
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
725
- index b1ea13c6854f..90144afc6f3a 100644
725
+ index b80e17f979da..32103f7b9c57 100644
726
726
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
727
727
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
728
728
@@ -6,9 +6,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
0 commit comments