- From 94cc2bb6a778cad3b762244d6d78ecf2e19b5372 Mon Sep 17 00:00:00 2001
- From: Md Abdullah Shahneous Bari <[email protected]>
- Date: Fri, 26 Apr 2024 20:20:28 +0000
- Subject: [PATCH 1/7] Add-support-for-VectorAnyINTEL-capability
-
- Allow vectors of any length between 2 and 2^63-1.
- The VectorAnyINTEL capability (part of the "SPV_INTEL_vector_compute"
- extension) relaxes the SPIR-V constraint that vector sizes be 2, 3, or 4.
-
- Also add support for the following:
-
- - Capability-inferred extension requirement checking.
- If a capability is a requirement, the extension that implements it should
- also become an extension requirement. There was no support for that check,
- so the extension requirement had to be added separately. Adding it
- separately causes problems when a feature is enabled by multiple
- capabilities and only one of them belongs to an extension. E.g., a vector
- size of 16 can be enabled by both the "Vector16" and "VectorAnyINTEL"
- capabilities, but only "VectorAnyINTEL" has an extension requirement
- ("SPV_INTEL_vector_compute"). Since capability and extension requirements
- are added independently, such cases could not be handled. Therefore, add
- the capability requirement first, then check for capability-inferred
- extensions (see the sketch after the patch headers below).
-
- - Add support for optionally skipping capability and extension requirement
+ From 45b150c9a0c4e4bd60c153e5142da17fd6cde6da Mon Sep 17 00:00:00 2001
+ From: izamyati <[email protected]>
+ Date: Tue, 24 Sep 2024 17:42:02 -0500
+ Subject: [PATCH] Add support for VectorAnyINTEL capability
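The capability-inferred extension check described in the old commit message above can be pictured with a small standalone sketch (illustrative C++ only, not the patch's code; the table contents and function names are assumptions): once a capability such as VectorAnyINTEL is chosen to satisfy a requirement, the extension that introduces it must become a requirement as well.

#include <map>
#include <set>
#include <string>
#include <vector>

// Illustrative stand-ins for the SPIR-V capability/extension enums.
using Capability = std::string;
using Extension = std::string;

// Hypothetical table: which extension (if any) introduces a capability.
// Vector16 is a core capability, so it has no entry here.
static const std::map<Capability, Extension> extensionImplementingCapability = {
    {"VectorAnyINTEL", "SPV_INTEL_vector_compute"},
};

// Pick the first candidate capability the target allows, and also record the
// extension that capability infers, if any.
bool satisfyCapabilityRequirement(const std::set<Capability> &targetCaps,
                                  const std::vector<Capability> &candidates,
                                  std::vector<Capability> &chosenCaps,
                                  std::vector<Extension> &impliedExts) {
  for (const Capability &cap : candidates) {
    if (!targetCaps.count(cap))
      continue;
    chosenCaps.push_back(cap);
    auto it = extensionImplementingCapability.find(cap);
    if (it != extensionImplementingCapability.end())
      impliedExts.push_back(it->second); // capability-inferred extension
    return true;
  }
  return false; // none of the candidate capabilities is available
}

int main() {
  std::set<Capability> targetCaps = {"VectorAnyINTEL"};
  std::vector<Capability> caps;
  std::vector<Extension> exts;
  // A 16-element vector can be enabled by either Vector16 or VectorAnyINTEL;
  // picking VectorAnyINTEL drags in SPV_INTEL_vector_compute automatically.
  satisfyCapabilityRequirement(targetCaps, {"Vector16", "VectorAnyINTEL"},
                               caps, exts);
}

The patch applies the same ordering inside SPIRVConversion.cpp's requirement checks below: pick the capability first, then pull in the extension it infers.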

---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 9 +-
 mlir/include/mlir/IR/CommonTypeConstraints.td | 86 ++++++++++++
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 7 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 24 +++-
- .../SPIRV/Transforms/SPIRVConversion.cpp | 132 +++++++++++++++---
+ .../SPIRV/Transforms/SPIRVConversion.cpp | 126 +++++++++++++++---
 .../arith-to-spirv-unsupported.mlir | 4 +-
 .../ArithToSPIRV/arith-to-spirv.mlir | 34 +++++
 .../FuncToSPIRV/types-to-spirv.mlir | 17 ++-
@@ -42,13 +21,13 @@ requirement initially, then do the check for capability inferred extension.
 mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 34 ++---
 mlir/test/Target/SPIRV/arithmetic-ops.mlir | 6 +-
 mlir/test/Target/SPIRV/ocl-ops.mlir | 6 +
- 17 files changed, 319 insertions(+), 68 deletions(-)
+ 17 files changed, 316 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
- index 6ec97e17c5dc..75e42c024553 100644
+ index 3b7da9b44a08..ddaeb13ef253 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
- @@ -4138,7 +4138,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
+ @@ -4142,7 +4142,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
@@ -62,8 +41,8 @@ index 6ec97e17c5dc..75e42c024553 100644
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
- @@ -4189,7 +4194,7 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
- "Joint Matrix">;
+ @@ -4185,7 +4190,7 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
+ "Cooperative Matrix">;

class SPIRV_VectorOf<Type type> :
- VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
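The SPIRV_VectorOf constraint removed here previously allowed only the fixed lengths 2, 3, 4, 8, and 16; the 86 new lines in CommonTypeConstraints.td (folded out of this excerpt) presumably supply a range-based length check to back the relaxed definition. The predicate itself is tiny; a hedged standalone sketch, grounded only in the commit message's [2, 2^63-1] range:

#include <cstdint>

// Sketch only: VectorAnyINTEL allows any vector length in [2, 2^63 - 1]; the
// actual TableGen constraint added by the patch is not shown in this excerpt.
bool isAllowedSpirvVectorLength(uint64_t n) {
  const uint64_t kMax = (UINT64_C(1) << 63) - 1; // 2^63 - 1
  return n >= 2 && n <= kMax;
}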
@@ -72,10 +51,10 @@ index 6ec97e17c5dc..75e42c024553 100644
class SPIRV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
- index af4f13dc0936..28d49d9e91f0 100644
+ index 211385245555..671ec270efe0 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
- @@ -608,6 +608,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
+ @@ -637,6 +637,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;

@@ -169,7 +148,7 @@ index af4f13dc0936..28d49d9e91f0 100644
// Negative values for `n` index in reverse.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
- index 72488d6e5d0b..b38f20458d32 100644
+ index 48be287ef833..aec6d64209dd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -187,9 +187,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
@@ -188,7 +167,7 @@ index 72488d6e5d0b..b38f20458d32 100644
return Type();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
- index 3f25696aa5eb..2d64fea0dc26 100644
+ index 337df3a5a65f..542c6beba2e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -100,9 +100,11 @@ bool CompositeType::classof(Type type) {
@@ -206,7 +185,7 @@ index 3f25696aa5eb..2d64fea0dc26 100644
}

Type CompositeType::getElementType(unsigned index) const {
- @@ -170,7 +172,21 @@ void CompositeType::getCapabilities(
+ @@ -164,7 +166,21 @@ void CompositeType::getCapabilities(
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
if (vecSize == 8 || vecSize == 16) {
@@ -230,10 +209,10 @@ index 3f25696aa5eb..2d64fea0dc26 100644
capabilities.push_back(ref);
}
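The getCapabilities() hunk above is where the per-size capability requirement for vectors is attached. One plausible reading of the patched logic, as a standalone sketch (candidate groups follow MLIR's "any of these capabilities" convention; the exact grouping is an assumption, since the added lines are folded out of this excerpt):

#include <cstdint>
#include <vector>

enum class Capability { Vector16, VectorAnyINTEL };

// Which capabilities could legalize a SPIR-V vector of `size` elements?
// Each inner vector is an "any of these" group.
std::vector<std::vector<Capability>> vectorCapabilities(uint64_t size) {
  if (size == 2 || size == 3 || size == 4)
    return {}; // core SPIR-V sizes: no extra capability needed
  if (size == 8 || size == 16)
    return {{Capability::Vector16, Capability::VectorAnyINTEL}};
  // Every other length in [2, 2^63 - 1] needs VectorAnyINTEL.
  return {{Capability::VectorAnyINTEL}};
}

Sizes 8 and 16 keep their historical Vector16 path, which is exactly the overlap between Vector16 and VectorAnyINTEL that the commit message calls out.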
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
- index 4072608dc8f8..3fc675632970 100644
+ index d833ec9309ba..36840582a114 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
- @@ -43,9 +43,13 @@ using namespace mlir;
+ @@ -88,9 +88,13 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
@@ -249,7 +228,7 @@ index 4072608dc8f8..3fc675632970 100644
continue;

LLVM_DEBUG({
- @@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
+ @@ -116,9 +120,13 @@ static LogicalResult checkExtensionRequirements(
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
@@ -265,7 +244,7 @@ index 4072608dc8f8..3fc675632970 100644
continue;

LLVM_DEBUG({
- @@ -90,6 +98,55 @@ static LogicalResult checkCapabilityRequirements(
+ @@ -135,6 +143,55 @@ static LogicalResult checkCapabilityRequirements(
return success();
}
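Both checkers touched in these hunks keep MLIR's usual requirement shape: the target must allow at least one member of every candidate group. A minimal sketch with simplified types (the real functions take a label for diagnostics and return LogicalResult):

#include <algorithm>
#include <set>
#include <string>
#include <vector>

using Capability = std::string;

// Simplified stand-in for spirv::TargetEnv::allows(Capability).
struct TargetEnv {
  std::set<Capability> allowed;
  bool allows(const Capability &cap) const { return allowed.count(cap) != 0; }
};

// Succeeds only if every "any of these" group has at least one capability the
// target environment allows.
bool checkCapabilityRequirements(
    const TargetEnv &env, const std::vector<std::vector<Capability>> &groups) {
  for (const auto &group : groups) {
    bool ok = std::any_of(group.begin(), group.end(), [&](const Capability &c) {
      return env.allows(c);
    });
    if (!ok)
      return false; // e.g. a 16-wide vector on a target with neither
                    // Vector16 nor VectorAnyINTEL enabled
  }
  return true;
}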
@@ -321,27 +300,24 @@ index 4072608dc8f8..3fc675632970 100644
/// Returns true if the given `storageClass` needs explicit layout when used in
/// Shader environments.
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
- @@ -247,12 +304,17 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
+ @@ -280,11 +337,16 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
return nullptr;
}

- - if (auto floatType = dyn_cast<FloatType>(type)) {
+ //if (auto floatType = dyn_cast<FloatType>(type)) {
+ // Convert to 32-bit float and remove floatType related capability
+ // restriction
- + if (auto floatType = dyn_cast<FloatType>(type)) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return Builder(targetEnv.getContext()).getF32Type();
}

- - auto intType = cast<IntegerType>(type);
+ //auto intType = cast<IntegerType>(type);
+ // Convert to 32-bit int and remove intType related capability restriction
- + auto intType = cast<IntegerType>(type);
+ auto intType = cast<IntegerType>(type);
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return IntegerType::get(targetEnv.getContext(), /*width=*/32,
- intType.getSignedness());
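The comments added in this convertScalarType() hunk describe the fallback that the debug message already hints at: when the target environment does not accept a scalar's native bitwidth, the type is rebuilt as a 32-bit float or integer. A width-only sketch of that decision (the real code constructs mlir::Type values through a Builder):

// Sketch: keep the native width when the target allows it, otherwise degrade
// to 32 bits, matching the "converted to 32-bit for SPIR-V" debug output.
unsigned convertScalarBitwidth(unsigned nativeWidth, bool widthAllowed) {
  return widthAllowed ? nativeWidth : 32;
}

The new comments in the hunk make the same point: the conversion to 32 bits stays, while the per-bitwidth capability restriction is dropped.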
- @@ -342,16 +404,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
+ @@ -375,16 +437,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);

@@ -389,7 +365,7 @@ index 4072608dc8f8..3fc675632970 100644
}

static Type
- @@ -1163,16 +1249,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
+ @@ -1553,16 +1639,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
@@ -400,10 +376,9 @@ index 4072608dc8f8..3fc675632970 100644
- return false;
-
typeCapabilities.clear();
- - cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
+ cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
- if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
- typeCapabilities)))
- + cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
+ typeExtensions.clear();
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
+ // Checking for capability and extension requirements along with capability
@@ -418,10 +393,10 @@ index 4072608dc8f8..3fc675632970 100644
}

diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
- index 0d92a8e676d8..d61ace8d6876 100644
+ index 24a0bab352c3..96b8ea6e7975 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
- @@ -11,9 +11,9 @@ module attributes {
+ @@ -28,9 +28,9 @@ module attributes {
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
} {

@@ -434,10 +409,10 @@ index 0d92a8e676d8..d61ace8d6876 100644
}

diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
- index ae47ae36ca51..644996fe0fa7 100644
+ index 1abe0fd2ec46..e485296ad026 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
- @@ -1447,6 +1447,40 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
+ @@ -1462,6 +1462,40 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
%3 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
@@ -586,7 +561,7 @@ index 53a1015de75b..6970b8ec0628 100644
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
- index 7dc0bd99f54b..5dd9901828cd 100644
+ index 5c24f0e6a7d3..3ca61ab48096 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -166,7 +166,7 @@ func.func @logicalUnary(%arg0 : i1)
@@ -599,10 +574,10 @@ index 7dc0bd99f54b..5dd9901828cd 100644
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
- index f7fd05b36bae..5228bb719d94 100644
+ index d8a26c71d12f..d22378817dbb 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
- @@ -439,7 +439,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
+ @@ -495,7 +495,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
// -----

func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 {
@@ -611,7 +586,7 @@ index f7fd05b36bae..5228bb719d94 100644
%0 = spirv.GroupNonUniformBitwiseAnd "Workgroup" "Reduce" %val : i1
return %0: i1
}
- @@ -460,7 +460,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {
+ @@ -516,7 +516,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {
// -----

func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 {
@@ -620,7 +595,7 @@ index f7fd05b36bae..5228bb719d94 100644
%0 = spirv.GroupNonUniformBitwiseOr "Workgroup" "Reduce" %val : i1
return %0: i1
}
- @@ -481,7 +481,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {
+ @@ -537,7 +537,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {
// -----

func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 {
@@ -629,7 +604,7 @@ index f7fd05b36bae..5228bb719d94 100644
%0 = spirv.GroupNonUniformBitwiseXor "Workgroup" "Reduce" %val : i1
return %0: i1
}
- @@ -502,7 +502,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {
+ @@ -558,7 +558,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {
// -----

func.func @group_non_uniform_logical_and(%val: i32) -> i32 {
@@ -638,7 +613,7 @@ index f7fd05b36bae..5228bb719d94 100644
%0 = spirv.GroupNonUniformLogicalAnd "Workgroup" "Reduce" %val : i32
return %0: i32
}
- @@ -523,7 +523,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {
+ @@ -579,7 +579,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {
// -----

func.func @group_non_uniform_logical_or(%val: i32) -> i32 {
@@ -647,7 +622,7 @@ index f7fd05b36bae..5228bb719d94 100644
%0 = spirv.GroupNonUniformLogicalOr "Workgroup" "Reduce" %val : i32
return %0: i32
}
- @@ -544,7 +544,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {
+ @@ -600,7 +600,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {
// -----

func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {