diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td index 3fcfb086f9662..1cdfa02f81787 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td @@ -1029,6 +1029,122 @@ def SPIRV_GLFMixOp : let hasVerifier = 0; } +// ----- + +def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [ + Pure, + AllTypesMatch<["p0", "p1"]>, + TypesMatchWith<"result type must match operand element type", + "p0", "result", + "::mlir::getElementTypeOrSelf($_self)"> + ]> { + let summary = "Return distance between two points"; + + let description = [{ + Result is the distance between p0 and p1, i.e., length(p0 - p1). + + The operands must all be a scalar or vector whose component type is floating-point. + + Result Type must be a scalar of the same type as the component type of the operands. + + #### Example: + + ```mlir + %2 = spirv.GL.Distance %0, %1 : vector<3xf32>, vector<3xf32> -> f32 + ``` + }]; + + let arguments = (ins + SPIRV_ScalarOrVectorOf:$p0, + SPIRV_ScalarOrVectorOf:$p1 + ); + + let results = (outs + SPIRV_Float:$result + ); + + let assemblyFormat = [{ + operands attr-dict `:` type($p0) `,` type($p1) `->` type($result) + }]; + + let hasVerifier = 0; +} + +// ----- + +def SPIRV_GLCrossOp : SPIRV_GLBinaryArithmeticOp<"Cross", 68, SPIRV_Float> { + let summary = "Return the cross product of two 3-component vectors"; + + let description = [{ + Result is the cross product of x and y, i.e., the resulting components are, in order: + + x[1] * y[2] - y[1] * x[2] + + x[2] * y[0] - y[2] * x[0] + + x[0] * y[1] - y[0] * x[1] + + All the operands must be vectors of 3 components of a floating-point type. + + Result Type and the type of all operands must be the same type. + + #### Example: + + ```mlir + %2 = spirv.GL.Cross %0, %1 : vector<3xf32> + %3 = spirv.GL.Cross %0, %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPIRV_GLNormalizeOp : SPIRV_GLUnaryArithmeticOp<"Normalize", 69, SPIRV_Float> { + let summary = "Normalizes a vector operand"; + + let description = [{ + Result is the vector in the same direction as x but with a length of 1. + + The operand x must be a scalar or vector whose component type is floating-point. + + Result Type and the type of x must be the same type. + + #### Example: + + ```mlir + %2 = spirv.GL.Normalize %0 : vector<3xf32> + %3 = spirv.GL.Normalize %1 : vector<4xf16> + ``` + }]; +} + +// ----- + +def SPIRV_GLReflectOp : SPIRV_GLBinaryArithmeticOp<"Reflect", 71, SPIRV_Float> { + let summary = "Calculate reflection direction vector"; + + let description = [{ + For the incident vector I and surface orientation N, the result is the reflection direction: + + I - 2 * dot(N, I) * N + + N must already be normalized in order to achieve the desired result. + + The operands must all be a scalar or vector whose component type is floating-point. + + Result Type and the type of all operands must be the same type. + + #### Example: + + ```mlir + %2 = spirv.GL.Reflect %0, %1 : f32 + %3 = spirv.GL.Reflect %0, %1 : vector<3xf32> + ``` + }]; +} + +// ---- + def SPIRV_GLFindUMsbOp : SPIRV_GLUnaryArithmeticOp<"FindUMsb", 75, SPIRV_Int32> { let summary = "Unsigned-integer most-significant bit"; diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir index 3683e5b469b17..beda3872bc8d2 100644 --- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -541,3 +541,125 @@ func.func @findumsb(%arg0 : i64) -> () { %2 = spirv.GL.FindUMsb %arg0 : i64 return } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.GL.Distance +//===----------------------------------------------------------------------===// + +func.func @distance_scalar(%arg0 : f32, %arg1 : f32) { + // CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : f32, f32 -> f32 + %0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> f32 + return +} + +func.func @distance_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) { + // CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xf32> -> f32 + %0 = spirv.GL.Distance %arg0, %arg1 : vector<3xf32>, vector<3xf32> -> f32 + return +} + +// ----- + +func.func @distance_invalid_type(%arg0 : i32, %arg1 : i32) { + // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + %0 = spirv.GL.Distance %arg0, %arg1 : i32, i32 -> f32 + return +} + +// ----- + +func.func @distance_arg_mismatch(%arg0 : vector<3xf32>, %arg1 : vector<4xf32>) { + // expected-error @+1 {{'spirv.GL.Distance' op failed to verify that all of {p0, p1} have same type}} + %0 = spirv.GL.Distance %arg0, %arg1 : vector<3xf32>, vector<4xf32> -> f32 + return +} + +// ----- + +func.func @distance_invalid_vector_size(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) { + // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + %0 = spirv.GL.Distance %arg0, %arg1 : vector<5xf32>, vector<5xf32> -> f32 + return +} + +// ----- + +func.func @distance_invalid_result(%arg0 : f32, %arg1 : f32) { + // expected-error @+1 {{'spirv.GL.Distance' op result #0 must be 16/32/64-bit float}} + %0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> i32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.GL.Cross +//===----------------------------------------------------------------------===// + +func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) { + %2 = spirv.GL.Cross %arg0, %arg1 : vector<3xf32> + // CHECK: %{{.+}} = spirv.GL.Cross %{{.+}}, %{{.+}} : vector<3xf32> + return +} + +// ----- + +func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) { + // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}} + %0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32> + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.GL.Normalize +//===----------------------------------------------------------------------===// + +func.func @normalize_scalar(%arg0 : f32) { + %2 = spirv.GL.Normalize %arg0 : f32 + // CHECK: %{{.+}} = spirv.GL.Normalize %{{.+}} : f32 + return +} + +func.func @normalize_vector(%arg0 : vector<3xf32>) { + %2 = spirv.GL.Normalize %arg0 : vector<3xf32> + // CHECK: %{{.+}} = spirv.GL.Normalize %{{.+}} : vector<3xf32> + return +} + +// ----- + +func.func @normalize_invalid_type(%arg0 : i32) { + // expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + %0 = spirv.GL.Normalize %arg0 : i32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.GL.Reflect +//===----------------------------------------------------------------------===// + +func.func @reflect_scalar(%arg0 : f32, %arg1 : f32) { + %2 = spirv.GL.Reflect %arg0, %arg1 : f32 + // CHECK: %{{.+}} = spirv.GL.Reflect %{{.+}}, %{{.+}} : f32 + return +} + +func.func @reflect_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) { + %2 = spirv.GL.Reflect %arg0, %arg1 : vector<3xf32> + // CHECK: %{{.+}} = spirv.GL.Reflect %{{.+}}, %{{.+}} : vector<3xf32> + return +} + +// ----- + +func.func @reflect_invalid_type(%arg0 : i32, %arg1 : i32) { + // expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + %0 = spirv.GL.Reflect %arg0, %arg1 : i32 + return +} diff --git a/mlir/test/Target/SPIRV/gl-ops.mlir b/mlir/test/Target/SPIRV/gl-ops.mlir index fff1adf0ae12c..119304cea7d4a 100644 --- a/mlir/test/Target/SPIRV/gl-ops.mlir +++ b/mlir/test/Target/SPIRV/gl-ops.mlir @@ -81,4 +81,24 @@ spirv.module Logical GLSL450 requires #spirv.vce { %2 = spirv.GL.FindUMsb %arg0 : i32 spirv.Return } + +spirv.func @vector(%arg0 : f32, %arg1 : vector<3xf32>, %arg2 : vector<3xf32>) "None" { + // CHECK: {{%.*}} = spirv.GL.Cross {{%.*}}, {{%.*}} : vector<3xf32> + %0 = spirv.GL.Cross %arg1, %arg2 : vector<3xf32> + // CHECK: {{%.*}} = spirv.GL.Normalize {{%.*}} : f32 + %1 = spirv.GL.Normalize %arg0 : f32 + // CHECK: {{%.*}} = spirv.GL.Normalize {{%.*}} : vector<3xf32> + %2 = spirv.GL.Normalize %arg1 : vector<3xf32> + // CHECK: {{%.*}} = spirv.GL.Reflect {{%.*}}, {{%.*}} : f32 + %3 = spirv.GL.Reflect %arg0, %arg0 : f32 + // CHECK: {{%.*}} = spirv.GL.Reflect {{%.*}}, {{%.*}} : vector<3xf32> + %4 = spirv.GL.Reflect %arg1, %arg2 : vector<3xf32> + // CHECK: {{%.*}} = spirv.GL.Distance {{%.*}}, {{%.*}} : f32, f32 -> f32 + %5 = spirv.GL.Distance %arg0, %arg0 : f32, f32 -> f32 + // CHECK: {{%.*}} = spirv.GL.Distance {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xf32> -> f32 + %6 = spirv.GL.Distance %arg1, %arg2 : vector<3xf32>, vector<3xf32> -> f32 + spirv.Return + } + + }