1
- From 263d874c1ffcb5b36dca54ecdc148767aadcb7d7 Mon Sep 17 00:00:00 2001
1
+ From 94cc2bb6a778cad3b762244d6d78ecf2e19b5372 Mon Sep 17 00:00:00 2001
2
2
From: Md Abdullah Shahneous Bari <
[email protected] >
3
- Date: Thu, 24 Aug 2023 09:05:47 -0700
4
- Subject: [PATCH 1/6 ] Add support for VectorAnyINTEL capability
3
+ Date: Fri, 26 Apr 2024 20:20:28 +0000
4
+ Subject: [PATCH 1/7 ] Add- support- for- VectorAnyINTEL- capability
5
5
6
6
Allow vector of any lengths between [2-2^63-1].
7
7
VectorAnyINTEL capability (part of "SPV_INTEL_vector_compute" extension)
@@ -23,22 +23,26 @@ cases like this. Therefore, for cases like this, enable adding capability
23
23
requirement initially, then do the check for capability inferred extension.
24
24
25
25
- Add support for optionally skipping capability and extension requirement
26
+
26
27
---
27
28
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 9 +-
28
29
mlir/include/mlir/IR/CommonTypeConstraints.td | 86 ++++++++++++
29
30
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 7 +-
30
31
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 24 +++-
31
32
.../SPIRV/Transforms/SPIRVConversion.cpp | 132 +++++++++++++++---
32
33
.../arith-to-spirv-unsupported.mlir | 4 +-
33
- .../ArithToSPIRV/arith-to-spirv.mlir | 35 +++++
34
+ .../ArithToSPIRV/arith-to-spirv.mlir | 34 +++++
34
35
.../FuncToSPIRV/types-to-spirv.mlir | 17 ++-
36
+ .../test/Dialect/SPIRV/IR/arithmetic-ops.mlir | 2 +-
35
37
mlir/test/Dialect/SPIRV/IR/bit-ops.mlir | 6 +-
36
38
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 2 +-
39
+ mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 4 +-
37
40
mlir/test/Dialect/SPIRV/IR/logical-ops.mlir | 2 +-
38
- mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 36 ++---
41
+ .../Dialect/SPIRV/IR/non-uniform-ops.mlir | 12 +-
42
+ mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 34 ++---
39
43
mlir/test/Target/SPIRV/arithmetic-ops.mlir | 6 +-
40
44
mlir/test/Target/SPIRV/ocl-ops.mlir | 6 +
41
- 14 files changed, 312 insertions(+), 60 deletions(-)
45
+ 17 files changed, 319 insertions(+), 68 deletions(-)
42
46
43
47
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
44
48
index 6ec97e17c5dc..75e42c024553 100644
@@ -68,10 +72,10 @@ index 6ec97e17c5dc..75e42c024553 100644
68
72
class SPIRV_ScalarOrVectorOf<Type type> :
69
73
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
70
74
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
71
- index 03180a687523..e4f2d5562ed7 100644
75
+ index af4f13dc0936..28d49d9e91f0 100644
72
76
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
73
77
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
74
- @@ -604 ,6 +604 ,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
78
+ @@ -608 ,6 +608 ,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
75
79
ScalableVectorOfLength<allowedLengths>.summary,
76
80
"::mlir::VectorType">;
77
81
@@ -165,7 +169,7 @@ index 03180a687523..e4f2d5562ed7 100644
165
169
// Negative values for `n` index in reverse.
166
170
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
167
171
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
168
- index e914f46bdef6..0b139b79f051 100644
172
+ index 72488d6e5d0b..b38f20458d32 100644
169
173
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
170
174
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
171
175
@@ -187,9 +187,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
@@ -226,7 +230,7 @@ index 3f25696aa5eb..2d64fea0dc26 100644
226
230
capabilities.push_back(ref);
227
231
}
228
232
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
229
- index 2b79c8022b8e..b778e4f4daf9 100644
233
+ index 4072608dc8f8..3fc675632970 100644
230
234
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
231
235
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
232
236
@@ -43,9 +43,13 @@ using namespace mlir;
@@ -385,7 +389,7 @@ index 2b79c8022b8e..b778e4f4daf9 100644
385
389
}
386
390
387
391
static Type
388
- @@ -1162 ,16 +1248 ,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
392
+ @@ -1163 ,16 +1249 ,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
389
393
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
390
394
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
391
395
for (Type valueType : valueTypes) {
@@ -430,18 +434,10 @@ index 0d92a8e676d8..d61ace8d6876 100644
430
434
}
431
435
432
436
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
433
- index ae47ae36ca51..0f5e79733574 100644
437
+ index ae47ae36ca51..644996fe0fa7 100644
434
438
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
435
439
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
436
- @@ -29,6 +29,7 @@ func.func @int32_scalar(%lhs: i32, %rhs: i32) {
437
-
438
- // CHECK-LABEL: @int32_scalar_srem
439
- // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
440
- + %1 = arith.subi %arg0, %arg0: vector<5xi32>
441
- func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
442
- // CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
443
- // CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
444
- @@ -1447,6 +1448,40 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
440
+ @@ -1447,6 +1447,40 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
445
441
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
446
442
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
447
443
%3 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
@@ -510,6 +506,19 @@ index 82d750755ffe..6f364c5b0875 100644
510
506
511
507
} // end module
512
508
509
+ diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
510
+ index 2d0c86e08de5..61fc0b53ed26 100644
511
+ --- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
512
+ +++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
513
+ @@ -283,7 +283,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
514
+ // -----
515
+
516
+ func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
517
+ - // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
518
+ + // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2-9223372036854775807, but got 'vector<4xi32>'}}
519
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
520
+ return %0 : i32
521
+ }
513
522
diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
514
523
index f3f0ebf60f46..2994b00d582c 100644
515
524
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -554,6 +563,28 @@ index 3683e5b469b1..a95a6001fd20 100644
554
563
%2 = spirv.GL.Exp %arg0 : vector<5xf32>
555
564
return
556
565
}
566
+ diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
567
+ index 53a1015de75b..6970b8ec0628 100644
568
+ --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
569
+ +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
570
+ @@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
571
+ // -----
572
+
573
+ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
574
+ - // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
575
+ + // expected-error @+1 {{op operand #0 must be Float32 or vector of Float32 values of length 2-9223372036854775807, but got 'f64'}}
576
+ %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
577
+ spirv.Return
578
+ }
579
+ @@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
580
+ // -----
581
+
582
+ spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
583
+ - // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
584
+ + // expected-error @+1 {{op result #0 must be Float32 or vector of Float32 values of length 2-9223372036854775807, but got 'f16'}}
585
+ %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
586
+ spirv.Return
587
+ }
557
588
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
558
589
index 7dc0bd99f54b..5dd9901828cd 100644
559
590
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -567,8 +598,66 @@ index 7dc0bd99f54b..5dd9901828cd 100644
567
598
%0 = spirv.LogicalNot %arg0 : i32
568
599
return
569
600
}
601
+ diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
602
+ index f7fd05b36bae..5228bb719d94 100644
603
+ --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
604
+ +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
605
+ @@ -439,7 +439,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
606
+ // -----
607
+
608
+ func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 {
609
+ - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
610
+ + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-9223372036854775807, but got 'i1'}}
611
+ %0 = spirv.GroupNonUniformBitwiseAnd "Workgroup" "Reduce" %val : i1
612
+ return %0: i1
613
+ }
614
+ @@ -460,7 +460,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {
615
+ // -----
616
+
617
+ func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 {
618
+ - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
619
+ + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-9223372036854775807, but got 'i1'}}
620
+ %0 = spirv.GroupNonUniformBitwiseOr "Workgroup" "Reduce" %val : i1
621
+ return %0: i1
622
+ }
623
+ @@ -481,7 +481,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {
624
+ // -----
625
+
626
+ func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 {
627
+ - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
628
+ + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-9223372036854775807, but got 'i1'}}
629
+ %0 = spirv.GroupNonUniformBitwiseXor "Workgroup" "Reduce" %val : i1
630
+ return %0: i1
631
+ }
632
+ @@ -502,7 +502,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {
633
+ // -----
634
+
635
+ func.func @group_non_uniform_logical_and(%val: i32) -> i32 {
636
+ - // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
637
+ + // expected-error @+1 {{op operand #0 must be bool or vector of bool values of length 2-9223372036854775807, but got 'i32'}}
638
+ %0 = spirv.GroupNonUniformLogicalAnd "Workgroup" "Reduce" %val : i32
639
+ return %0: i32
640
+ }
641
+ @@ -523,7 +523,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {
642
+ // -----
643
+
644
+ func.func @group_non_uniform_logical_or(%val: i32) -> i32 {
645
+ - // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
646
+ + // expected-error @+1 {{op operand #0 must be bool or vector of bool values of length 2-9223372036854775807, but got 'i32'}}
647
+ %0 = spirv.GroupNonUniformLogicalOr "Workgroup" "Reduce" %val : i32
648
+ return %0: i32
649
+ }
650
+ @@ -544,7 +544,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {
651
+ // -----
652
+
653
+ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
654
+ - // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
655
+ + // expected-error @+1 {{op operand #0 must be bool or vector of bool values of length 2-9223372036854775807, but got 'i32'}}
656
+ %0 = spirv.GroupNonUniformLogicalXor "Workgroup" "Reduce" %val : i32
657
+ return %0: i32
658
+ }
570
659
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
571
- index 81ba471d3f51..2dbebb2db98e 100644
660
+ index 81ba471d3f51..7a29abd44b34 100644
572
661
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
573
662
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
574
663
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
@@ -625,15 +714,7 @@ index 81ba471d3f51..2dbebb2db98e 100644
625
714
func.func @sabsi64(%arg0 : i64) -> () {
626
715
// CHECK: spirv.CL.s_abs {{%.*}} : i64
627
716
%2 = spirv.CL.s_abs %arg0 : i64
628
- @@ -137,21 +145,13 @@ func.func @sabsi8(%arg0 : i8) -> () {
629
- // -----
630
-
631
- func.func @sabs(%arg0 : f32) -> () {
632
- - // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
633
- + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
634
- %2 = spirv.CL.s_abs %arg0 : f32
635
- return
636
- }
717
+ @@ -144,14 +152,6 @@ func.func @sabs(%arg0 : f32) -> () {
637
718
638
719
// -----
639
720
0 commit comments