diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index d5359da2a590e..cd5d201c3d5da 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4489,6 +4489,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>; def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>; def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>; +def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>; def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>; def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>; @@ -4598,7 +4599,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd, SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor, SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr, - SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR, + SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR, + SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td index 98e435c18d3d7..2dd3dbd28d436 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -1361,4 +1361,78 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo // ----- +def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [ + Pure, AllTypesMatch<["value", "result"]>]> { + let summary = [{ + Rotate values across invocations within a subgroup. + }]; + + let description = [{ + Return the Value of the invocation whose id within the group is calculated + as follows: + + LocalId = SubgroupLocalInvocationId if Execution is Subgroup or + LocalInvocationId if Execution is Workgroup + RotationGroupSize = ClusterSize when ClusterSize is present, otherwise + RotationGroupSize = SubgroupMaxSize if the Kernel capability is declared + and SubgroupSize if not. + Invocation ID = ( (LocalId + Delta) & (RotationGroupSize - 1) ) + + (LocalId & ~(RotationGroupSize - 1)) + + Result Type must be a scalar or vector of floating-point type, integer + type, or Boolean type. + + Execution is a Scope. It must be either Workgroup or Subgroup. + + The type of Value must be the same as Result Type. + + Delta must be a scalar of integer type, whose Signedness operand is 0. + Delta must be dynamically uniform within Execution. + + Delta is treated as unsigned and the resulting value is undefined if the + selected lane is inactive. + + ClusterSize is the size of cluster to use. ClusterSize must be a scalar of + integer type, whose Signedness operand is 0. ClusterSize must come from a + constant instruction. Behavior is undefined unless ClusterSize is at least + 1 and a power of 2. If ClusterSize is greater than the declared + SubGroupSize, executing this instruction results in undefined behavior. + + + + #### Example: + + ```mlir + %four = spirv.Constant 4 : i32 + %0 = spirv.GroupNonUniformRotateKHR , %value, %delta : f32, i32 -> f32 + %1 = spirv.GroupNonUniformRotateKHR , %value, %delta, + clustersize(%four) : f32, i32, i32 -> f32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_GroupNonUniformRotateKHR]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + AnyTypeOf<[SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf]>:$value, + SPIRV_SignlessOrUnsignedInt:$delta, + Optional:$cluster_size + ); + + let results = (outs + AnyTypeOf<[SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf]>:$result + ); + + let assemblyFormat = [{ + $execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results) + }]; +} + +// ----- + #endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp index 8aeafda0eb755..461d037134dae 100644 --- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp @@ -304,6 +304,29 @@ LogicalResult GroupNonUniformLogicalXorOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformRotateKHR +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformRotateKHROp::verify() { + spirv::Scope scope = getExecutionScope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); + + if (Value clusterSizeVal = getClusterSize()) { + mlir::Operation *defOp = clusterSizeVal.getDefiningOp(); + int32_t clusterSize = 0; + + if (failed(extractValueFromConstOp(defOp, clusterSize))) + return emitOpError("cluster size operand must come from a constant op"); + + if (!llvm::isPowerOf2_32(clusterSize)) + return emitOpError("cluster size operand must be a power of two"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // Group op verification //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir index 60ae1584d29fb..bf383d3837b6e 100644 --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -604,3 +604,70 @@ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 { %0 = spirv.GroupNonUniformLogicalXor %val : i32 -> i32 return %0: i32 } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformRotateKHR +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @group_non_uniform_rotate_khr +func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 { + // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR , %{{.+}} : f32, i32 -> f32 + %0 = spirv.GroupNonUniformRotateKHR , %val, %delta : f32, i32 -> f32 + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @group_non_uniform_rotate_khr +func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 { + // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR , %{{.+}} : f32, i32, i32 -> f32 + %four = spirv.Constant 4 : i32 + %0 = spirv.GroupNonUniformRotateKHR , %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32 + return %0: f32 +} + +// ----- + +func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 { + %four = spirv.Constant 4 : i32 + // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}} + %0 = spirv.GroupNonUniformRotateKHR , %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32 + return %0: f32 +} + +// ----- + +func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 { + %four = spirv.Constant 4 : i32 + // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}} + %0 = spirv.GroupNonUniformRotateKHR , %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32 + return %0: f32 +} + +// ----- + +func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 { + %four = spirv.Constant 4 : si32 + // expected-error @+1 {{op operand #2 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}} + %0 = spirv.GroupNonUniformRotateKHR , %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32 + return %0: f32 +} + +// ----- + +func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32, %four: i32) -> f32 { + // expected-error @+1 {{cluster size operand must come from a constant op}} + %0 = spirv.GroupNonUniformRotateKHR , %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32 + return %0: f32 +} + +// ----- + +func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 { + %five = spirv.Constant 5 : i32 + // expected-error @+1 {{cluster size operand must be a power of two}} + %0 = spirv.GroupNonUniformRotateKHR , %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32 + return %0: f32 +}