From 1361daa941948599b508163d203cb5130b2816c4 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Fri, 23 May 2025 15:02:01 -0600 Subject: [PATCH 1/2] [mlir][spirv] Add GroupNonUniformVote instructions --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 42 ++--- .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 160 ++++++++++++++++++ mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 33 ++++ .../Dialect/SPIRV/IR/non-uniform-ops.mlir | 73 ++++++++ mlir/test/Target/SPIRV/non-uniform-ops.mlir | 18 ++ 5 files changed, 307 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index cd5d201c3d5da..8fd533db83d9a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4464,6 +4464,9 @@ def SPIRV_OC_OpGroupSMax : I32EnumAttrCase<"OpGroupSMax", 2 def SPIRV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>; def SPIRV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>; def SPIRV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>; +def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUniformAll", 334>; +def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>; +def SPIRV_OC_OpGroupNonUniformAllEqual : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>; def SPIRV_OC_OpGroupNonUniformBroadcast : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>; def SPIRV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>; def SPIRV_OC_OpGroupNonUniformBallotBitCount : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>; @@ -4489,8 +4492,8 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>; def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>; def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>; -def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>; def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; +def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>; def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>; def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>; def SPIRV_OC_OpSUDot : I32EnumAttrCase<"OpSUDot", 4452>; @@ -4581,11 +4584,13 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor, SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge, SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional, - SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, SPIRV_OC_OpUnreachable, - SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd, SPIRV_OC_OpGroupFAdd, - SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin, SPIRV_OC_OpGroupSMin, - SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax, SPIRV_OC_OpGroupSMax, - SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed, SPIRV_OC_OpGroupNonUniformElect, + SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, + SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd, + SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin, + SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax, + SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed, + SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll, + SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual, SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot, SPIRV_OC_OpGroupNonUniformBallotBitCount, SPIRV_OC_OpGroupNonUniformBallotFindLSB, @@ -4599,19 +4604,18 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd, SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor, SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr, - SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR, - SPIRV_OC_OpSubgroupBallotKHR, - SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, - SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat, - SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR, - SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR, - SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpEmitMeshTasksEXT, - SPIRV_OC_OpSetMeshOutputsEXT, SPIRV_OC_OpSubgroupBlockReadINTEL, - SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR, - SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL, - SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL, - SPIRV_OC_OpControlBarrierWaitINTEL, SPIRV_OC_OpGroupIMulKHR, - SPIRV_OC_OpGroupFMulKHR + SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR, + SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, + SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat, + SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR, + SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR, + SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR, + SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT, + SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL, + SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, + SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL, + SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL, + SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td index 3fdaff2470cba..db337f577b37e 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -1435,4 +1435,164 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [ // ----- +def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", []> { + let summary = [{ + Evaluates a predicate for all tangled invocations within the Execution + scope, resulting in true if predicate evaluates to true for all tangled + invocations within the Execution scope, otherwise the result is false. + }]; + + let description = [{ + Result Type must be a Boolean type. + + Execution is the scope defining the scope restricted tangle affected by + this command. It must be Subgroup. + + Predicate must be a Boolean type. + + An invocation will not execute a dynamic instance of this instruction + (X') until all invocations in its scope restricted tangle have executed + all dynamic instances that are program-ordered before X'. + + + + #### Example: + + ```mlir + %predicate = ... : i1 + %0 = spirv.GroupNonUniformAll "Subgroup" %predicate : i1 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_GroupNonUniformVote]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_Bool:$predicate + ); + + let results = (outs + SPIRV_Bool:$result + ); + + let assemblyFormat = [{ + $execution_scope $predicate attr-dict `:` type($result) + }]; +} + +// ----- + +def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", []> { + let summary = [{ + Evaluates a predicate for all tangled invocations within the Execution + scope, resulting in true if predicate evaluates to true for any tangled + invocations within the Execution scope, otherwise the result is false. + }]; + + let description = [{ + Result Type must be a Boolean type. + + Execution is the scope defining the scope restricted tangle affected by + this command. It must be Subgroup. + + Predicate must be a Boolean type. + + An invocation will not execute a dynamic instance of this instruction + (X') until all invocations in its scope restricted tangle have executed + all dynamic instances that are program-ordered before X'. + + + + #### Example: + + ```mlir + %predicate = ... : i1 + %0 = spirv.GroupNonUniformAny "Subgroup" %predicate : i1 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_GroupNonUniformVote]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_Bool:$predicate + ); + + let results = (outs + SPIRV_Bool:$result + ); + + let assemblyFormat = [{ + $execution_scope $predicate attr-dict `:` type($result) + }]; +} + +// ----- + +def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", []> { + let summary = [{ + Evaluates a value for all tangled invocations within the Execution + scope. The result is true if Value is equal for all tangled invocations + within the Execution scope. Otherwise, the result is false. + }]; + + let description = [{ + Result Type must be a Boolean type. + + Execution is the scope defining the scope restricted tangle affected by + this command. It must be Subgroup. + + Value must be a scalar or vector of floating-point type, integer type, + or Boolean type. The compare operation is based on this type, and if it + is a floating-point type, an ordered-and-equal compare is used. + + An invocation will not execute a dynamic instance of this instruction + (X') until all invocations in its scope restricted tangle have executed + all dynamic instances that are program-ordered before X'. + + + + #### Example: + + ```mlir + %scalar_value = ... : f32 + %vector_value = ... : vector<4xf32> + %0 = spirv.GroupNonUniformAllEqual %scalar_value : f32, i1 + %1 = spirv.GroupNonUniformAllEqual %vector_value : vector<4xf32>, i1 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_GroupNonUniformVote]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + AnyTypeOf<[SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf]>:$value + ); + + let results = (outs + SPIRV_Bool:$result + ); + + let assemblyFormat = [{ + $execution_scope $value attr-dict `:` type($value) `,` type($result) + }]; +} + +// ----- + #endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp index 461d037134dae..aba876c1c80f4 100644 --- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp @@ -327,6 +327,39 @@ LogicalResult GroupNonUniformRotateKHROp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformAllOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformAllOp::verify() { + if (getExecutionScope() != spirv::Scope::Subgroup) + return emitOpError("execution scope must be 'Subgroup'"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformAllOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformAnyOp::verify() { + if (getExecutionScope() != spirv::Scope::Subgroup) + return emitOpError("execution scope must be 'Subgroup'"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformAllEqualOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformAllEqualOp::verify() { + if (getExecutionScope() != spirv::Scope::Subgroup) + return emitOpError("execution scope must be 'Subgroup'"); + + return success(); +} + //===----------------------------------------------------------------------===// // Group op verification //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir index 6990f2b3751f5..d7c840dc6a8ef 100644 --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -671,3 +671,76 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 { %0 = spirv.GroupNonUniformRotateKHR %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32 return %0: f32 } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformAll +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @group_non_uniform_all +func.func @group_non_uniform_all(%predicate: i1) -> i1 { + // CHECK: %{{.+}} = spirv.GroupNonUniformAll %{{.+}} : i1 + %0 = spirv.GroupNonUniformAll %predicate : i1 + return %0: i1 +} + +// ----- + +func.func @group_non_uniform_all(%predicate: i1) -> i1 { + // expected-error @+1 {{execution scope must be 'Subgroup'}} + %0 = spirv.GroupNonUniformAll %predicate : i1 + return %0: i1 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformAny +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @group_non_uniform_any +func.func @group_non_uniform_any(%predicate: i1) -> i1 { + // CHECK: %{{.+}} = spirv.GroupNonUniformAny %{{.+}} : i1 + %0 = spirv.GroupNonUniformAny %predicate : i1 + return %0: i1 +} + +// ----- + +func.func @group_non_uniform_any(%predicate: i1) -> i1 { + // expected-error @+1 {{execution scope must be 'Subgroup'}} + %0 = spirv.GroupNonUniformAny %predicate : i1 + return %0: i1 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformAllEqual +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @group_non_uniform_all_equal +func.func @group_non_uniform_all_equal(%value: f32) -> i1 { + // CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual %{{.+}} : f32, i1 + %0 = spirv.GroupNonUniformAllEqual %value : f32, i1 + return %0: i1 +} + +// ----- + +// CHECK-LABEL: @group_non_uniform_all_equal +func.func @group_non_uniform_all_equal(%value: vector<4xi32>) -> i1 { + // CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual %{{.+}} : vector<4xi32>, i1 + %0 = spirv.GroupNonUniformAllEqual %value : vector<4xi32>, i1 + return %0: i1 +} + + +// ----- + +func.func @group_non_uniform_all_equal(%value: f32) -> i1 { + // expected-error @+1 {{execution scope must be 'Subgroup'}} + %0 = spirv.GroupNonUniformAllEqual %value : f32, i1 + return %0: i1 +} diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir index 3e78eaf8b03ef..f29ebd86a2e03 100644 --- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir +++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir @@ -124,4 +124,22 @@ spirv.module Logical GLSL450 requires #spirv.vce { %0 = spirv.GroupNonUniformShuffleXor %val, %id : f32, i32 spirv.ReturnValue %0: f32 } + + spirv.func @group_non_uniform_all(%pred: i1) -> i1 "None" { + // CHECK: %{{.+}} = spirv.GroupNonUniformAll %{{.+}} : i1 + %0 = spirv.GroupNonUniformAll %pred : i1 + spirv.ReturnValue %0: i1 + } + + spirv.func @group_non_uniform_any(%pred: i1) -> i1 "None" { + // CHECK: %{{.+}} = spirv.GroupNonUniformAny %{{.+}} : i1 + %0 = spirv.GroupNonUniformAny %pred : i1 + spirv.ReturnValue %0: i1 + } + + spirv.func @group_non_uniform_all_equal(%val: vector<4xi32>) -> i1 "None" { + // CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual %{{.+}} : vector<4xi32>, i1 + %0 = spirv.GroupNonUniformAllEqual %val : vector<4xi32>, i1 + spirv.ReturnValue %0: i1 + } } From 573c744863447bfc9b3aa4de42855c95b5be4511 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sun, 25 May 2025 02:15:24 -0400 Subject: [PATCH 2/2] Use existing attribute constraint and remove custom verifier functions --- .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 19 +++++++++-- mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 33 ------------------- .../Dialect/SPIRV/IR/non-uniform-ops.mlir | 6 ++-- 3 files changed, 19 insertions(+), 39 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td index db337f577b37e..7e2ab64afc6d0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -1435,7 +1435,9 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [ // ----- -def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", []> { +def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", [ + SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup"> +]> { let summary = [{ Evaluates a predicate for all tangled invocations within the Execution scope, resulting in true if predicate evaluates to true for all tangled @@ -1480,6 +1482,8 @@ def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", []> { SPIRV_Bool:$result ); + let hasVerifier = 0; + let assemblyFormat = [{ $execution_scope $predicate attr-dict `:` type($result) }]; @@ -1487,7 +1491,9 @@ def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", []> { // ----- -def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", []> { +def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", [ + SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup"> +]> { let summary = [{ Evaluates a predicate for all tangled invocations within the Execution scope, resulting in true if predicate evaluates to true for any tangled @@ -1532,6 +1538,8 @@ def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", []> { SPIRV_Bool:$result ); + let hasVerifier = 0; + let assemblyFormat = [{ $execution_scope $predicate attr-dict `:` type($result) }]; @@ -1539,7 +1547,9 @@ def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", []> { // ----- -def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", []> { +def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [ + SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup"> +]> { let summary = [{ Evaluates a value for all tangled invocations within the Execution scope. The result is true if Value is equal for all tangled invocations @@ -1588,6 +1598,9 @@ def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", []> { SPIRV_Bool:$result ); + + let hasVerifier = 0; + let assemblyFormat = [{ $execution_scope $value attr-dict `:` type($value) `,` type($result) }]; diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp index aba876c1c80f4..461d037134dae 100644 --- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp @@ -327,39 +327,6 @@ LogicalResult GroupNonUniformRotateKHROp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformAllOp -//===----------------------------------------------------------------------===// - -LogicalResult GroupNonUniformAllOp::verify() { - if (getExecutionScope() != spirv::Scope::Subgroup) - return emitOpError("execution scope must be 'Subgroup'"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformAllOp -//===----------------------------------------------------------------------===// - -LogicalResult GroupNonUniformAnyOp::verify() { - if (getExecutionScope() != spirv::Scope::Subgroup) - return emitOpError("execution scope must be 'Subgroup'"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformAllEqualOp -//===----------------------------------------------------------------------===// - -LogicalResult GroupNonUniformAllEqualOp::verify() { - if (getExecutionScope() != spirv::Scope::Subgroup) - return emitOpError("execution scope must be 'Subgroup'"); - - return success(); -} - //===----------------------------------------------------------------------===// // Group op verification //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir index d7c840dc6a8ef..5f56de6ad1fa9 100644 --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -688,7 +688,7 @@ func.func @group_non_uniform_all(%predicate: i1) -> i1 { // ----- func.func @group_non_uniform_all(%predicate: i1) -> i1 { - // expected-error @+1 {{execution scope must be 'Subgroup'}} + // expected-error @+1 {{execution_scope must be Scope of value Subgroup}} %0 = spirv.GroupNonUniformAll %predicate : i1 return %0: i1 } @@ -709,7 +709,7 @@ func.func @group_non_uniform_any(%predicate: i1) -> i1 { // ----- func.func @group_non_uniform_any(%predicate: i1) -> i1 { - // expected-error @+1 {{execution scope must be 'Subgroup'}} + // expected-error @+1 {{execution_scope must be Scope of value Subgroup}} %0 = spirv.GroupNonUniformAny %predicate : i1 return %0: i1 } @@ -740,7 +740,7 @@ func.func @group_non_uniform_all_equal(%value: vector<4xi32>) -> i1 { // ----- func.func @group_non_uniform_all_equal(%value: f32) -> i1 { - // expected-error @+1 {{execution scope must be 'Subgroup'}} + // expected-error @+1 {{execution_scope must be Scope of value Subgroup}} %0 = spirv.GroupNonUniformAllEqual %value : f32, i1 return %0: i1 }