1
- From c10fb55e593cff647d3d8835799b3c39d208cfcc Mon Sep 17 00:00:00 2001
1
+ From 46f1b01b6cde4956f4b08985adc59dd530788d4a Mon Sep 17 00:00:00 2001
2
2
From: Garra1980 <
[email protected] >
3
- Date: Fri, 21 Feb 2025 19:20:18 +0100
4
- Subject: [PATCH 1/1 ] Add support for VectorAnyINTEL capability
3
+ Date: Fri, 13 Jun 2025 19:12:11 +0200
4
+ Subject: [PATCH] Add support for VectorAnyINTEL capability
5
5
6
6
---
7
7
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 +-
@@ -24,13 +24,13 @@ Subject: [PATCH 1/1] Add support for VectorAnyINTEL capability
24
24
17 files changed, 322 insertions(+), 69 deletions(-)
25
25
26
26
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
27
- index cafe14046957..b01e581e0f15 100644
27
+ index d2ba76cdad90..ac491f6068a0 100644
28
28
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
29
29
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
30
- @@ -4163,7 +4163,14 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
31
- def SPIRV_Float32 : TypeAlias<F32, "Float32">;
30
+ @@ -4194,7 +4194,14 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
32
31
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
33
32
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
33
+ def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
34
34
- def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
35
35
+ // Remove the vector size restriction.
36
36
+ // Vector type is quite restrictive in SPIR-V.
@@ -40,10 +40,10 @@ index cafe14046957..b01e581e0f15 100644
40
40
+ // via VectorAnyINTEL capability (SPV_INTEL_vector_compute extension).
41
41
+ // It allows vector length of 2 to 2^32-1.
42
42
+ def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
43
- [SPIRV_Bool, SPIRV_Integer, SPIRV_Float ]>;
43
+ [SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat ]>;
44
44
// Component type check is done in the type parser for the following SPIR-V
45
45
// dialect-specific types so we use "Any" here.
46
- @@ -4213 ,7 +4220 ,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
46
+ @@ -4245 ,7 +4252 ,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
47
47
"Matrix">;
48
48
49
49
class SPIRV_VectorOf<Type type> :
@@ -53,10 +53,10 @@ index cafe14046957..b01e581e0f15 100644
53
53
class SPIRV_ScalarOrVectorOf<Type type> :
54
54
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
55
55
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
56
- index a18b32253d85..4ba962fbeecc 100644
56
+ index 45ec1846580f..6ca59f91eee9 100644
57
57
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
58
58
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
59
- @@ -643 ,6 +643 ,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
59
+ @@ -648 ,6 +648 ,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
60
60
ScalableVectorOfLength<allowedLengths>.summary,
61
61
"::mlir::VectorType">;
62
62
@@ -150,10 +150,10 @@ index a18b32253d85..4ba962fbeecc 100644
150
150
// Negative values for `n` index in reverse.
151
151
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
152
152
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
153
- index 0cf5f0823be6..92f0319db022 100644
153
+ index a21acef1c4b4..a7c60f6bf1cb 100644
154
154
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
155
155
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
156
- @@ -191 ,9 +191 ,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
156
+ @@ -188 ,9 +188 ,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
157
157
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
158
158
return Type();
159
159
}
@@ -169,7 +169,7 @@ index 0cf5f0823be6..92f0319db022 100644
169
169
return Type();
170
170
}
171
171
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
172
- index 337df3a5a65f..275b5aa507fb 100644
172
+ index 93e0c9b33c54..a349da00027e 100644
173
173
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
174
174
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
175
175
@@ -100,9 +100,10 @@ bool CompositeType::classof(Type type) {
@@ -210,10 +210,10 @@ index 337df3a5a65f..275b5aa507fb 100644
210
210
capabilities.push_back(ref);
211
211
}
212
212
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
213
- index c56dbcca2175..03777d4a98b8 100644
213
+ index f5a58c58e05d..406e81235e10 100644
214
214
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
215
215
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
216
- @@ -88 ,9 +88 ,13 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
216
+ @@ -87 ,9 +87 ,13 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
217
217
template <typename LabelT>
218
218
static LogicalResult checkExtensionRequirements(
219
219
LabelT label, const spirv::TargetEnv &targetEnv,
@@ -229,7 +229,7 @@ index c56dbcca2175..03777d4a98b8 100644
229
229
continue;
230
230
231
231
LLVM_DEBUG({
232
- @@ -116 ,9 +120 ,13 @@ static LogicalResult checkExtensionRequirements(
232
+ @@ -115 ,9 +119 ,13 @@ static LogicalResult checkExtensionRequirements(
233
233
template <typename LabelT>
234
234
static LogicalResult checkCapabilityRequirements(
235
235
LabelT label, const spirv::TargetEnv &targetEnv,
@@ -245,7 +245,7 @@ index c56dbcca2175..03777d4a98b8 100644
245
245
continue;
246
246
247
247
LLVM_DEBUG({
248
- @@ -135 ,6 +143 ,55 @@ static LogicalResult checkCapabilityRequirements(
248
+ @@ -134 ,6 +142 ,55 @@ static LogicalResult checkCapabilityRequirements(
249
249
return success();
250
250
}
251
251
@@ -301,7 +301,7 @@ index c56dbcca2175..03777d4a98b8 100644
301
301
/// Returns true if the given `storageClass` needs explicit layout when used in
302
302
/// Shader environments.
303
303
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
304
- @@ -280 ,11 +337 ,14 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
304
+ @@ -279 ,11 +336 ,14 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
305
305
return nullptr;
306
306
}
307
307
@@ -316,7 +316,7 @@ index c56dbcca2175..03777d4a98b8 100644
316
316
auto intType = cast<IntegerType>(type);
317
317
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
318
318
return IntegerType::get(targetEnv.getContext(), /*width=*/32,
319
- @@ -359 ,10 +419 ,13 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
319
+ @@ -358 ,10 +418 ,13 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
320
320
321
321
if (type.getRank() <= 1 && type.getNumElements() == 1)
322
322
return elementType;
@@ -334,7 +334,7 @@ index c56dbcca2175..03777d4a98b8 100644
334
334
return nullptr;
335
335
}
336
336
337
- @@ -384 ,16 +447 ,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
337
+ @@ -383 ,16 +446 ,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
338
338
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
339
339
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
340
340
@@ -382,7 +382,7 @@ index c56dbcca2175..03777d4a98b8 100644
382
382
}
383
383
384
384
static Type
385
- @@ -1562 ,16 +1649 ,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
385
+ @@ -1563 ,16 +1650 ,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
386
386
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
387
387
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
388
388
for (Type valueType : valueTypes) {
@@ -471,10 +471,10 @@ index 1abe0fd2ec46..f64436fa2632 100644
471
471
}
472
472
473
473
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
474
- index 82d750755ffe..6f364c5b0875 100644
474
+ index 1737f4a906bf..13f4e17167ef 100644
475
475
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
476
476
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
477
- @@ -351 ,8 +351 ,21 @@ module attributes {
477
+ @@ -345 ,8 +345 ,21 @@ module attributes {
478
478
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
479
479
} {
480
480
@@ -499,10 +499,10 @@ index 82d750755ffe..6f364c5b0875 100644
499
499
} // end module
500
500
501
501
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
502
- index 2d0c86e08de5..f60bd10c115b 100644
502
+ index d58c27598f2b..4c22244c08e0 100644
503
503
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
504
504
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
505
- @@ -283 ,7 +283 ,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
505
+ @@ -339 ,7 +339 ,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
506
506
// -----
507
507
508
508
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
@@ -543,7 +543,7 @@ index f3f0ebf60f46..1138f38bcef2 100644
543
543
return %0 : f16
544
544
}
545
545
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
546
- index beda3872bc8d..75e4c1b9a43d 100644
546
+ index 642346cc40b0..10ede222ada7 100644
547
547
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
548
548
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
549
549
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
@@ -578,7 +578,7 @@ index bb15d018a6c4..f23c2b329a51 100644
578
578
spirv.Return
579
579
}
580
580
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
581
- index 5c24f0e6a7d3..5cbdc5e1e5ef 100644
581
+ index d6c34645f574..c24892a00d5a 100644
582
582
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
583
583
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
584
584
@@ -166,7 +166,7 @@ func.func @logicalUnary(%arg0 : i1)
@@ -591,10 +591,10 @@ index 5c24f0e6a7d3..5cbdc5e1e5ef 100644
591
591
return
592
592
}
593
593
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
594
- index 60ae1584d29f..ac6598b42b03 100644
594
+ index 7ab94f17360d..07d85ca5fa90 100644
595
595
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
596
596
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
597
- @@ -495 ,7 +495 ,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
597
+ @@ -511 ,7 +511 ,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
598
598
// -----
599
599
600
600
func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 {
@@ -603,7 +603,7 @@ index 60ae1584d29f..ac6598b42b03 100644
603
603
%0 = spirv.GroupNonUniformBitwiseAnd <Workgroup> <Reduce> %val : i1 -> i1
604
604
return %0: i1
605
605
}
606
- @@ -516 ,7 +516 ,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {
606
+ @@ -532 ,7 +532 ,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {
607
607
// -----
608
608
609
609
func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 {
@@ -612,7 +612,7 @@ index 60ae1584d29f..ac6598b42b03 100644
612
612
%0 = spirv.GroupNonUniformBitwiseOr <Workgroup> <Reduce> %val : i1 -> i1
613
613
return %0: i1
614
614
}
615
- @@ -537 ,7 +537 ,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {
615
+ @@ -553 ,7 +553 ,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {
616
616
// -----
617
617
618
618
func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 {
@@ -621,7 +621,7 @@ index 60ae1584d29f..ac6598b42b03 100644
621
621
%0 = spirv.GroupNonUniformBitwiseXor <Workgroup> <Reduce> %val : i1 -> i1
622
622
return %0: i1
623
623
}
624
- @@ -558 ,7 +558 ,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {
624
+ @@ -574 ,7 +574 ,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {
625
625
// -----
626
626
627
627
func.func @group_non_uniform_logical_and(%val: i32) -> i32 {
@@ -630,7 +630,7 @@ index 60ae1584d29f..ac6598b42b03 100644
630
630
%0 = spirv.GroupNonUniformLogicalAnd <Workgroup> <Reduce> %val : i32 -> i32
631
631
return %0: i32
632
632
}
633
- @@ -579 ,7 +579 ,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {
633
+ @@ -595 ,7 +595 ,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {
634
634
// -----
635
635
636
636
func.func @group_non_uniform_logical_or(%val: i32) -> i32 {
@@ -639,7 +639,7 @@ index 60ae1584d29f..ac6598b42b03 100644
639
639
%0 = spirv.GroupNonUniformLogicalOr <Workgroup> <Reduce> %val : i32 -> i32
640
640
return %0: i32
641
641
}
642
- @@ -600 ,7 +600 ,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {
642
+ @@ -616 ,7 +616 ,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {
643
643
// -----
644
644
645
645
func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
0 commit comments