diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index a4c01c0bc3418..469a9a0ef01dd 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4171,6 +4171,7 @@ def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">; def SPIRV_IsCooperativeMatrixType : CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">; def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">; +def SPIRV_IsVectorType : CPred<"::llvm::isa<::mlir::VectorType>($_self)">; def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">; def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">; def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">; @@ -4202,6 +4203,8 @@ def SPIRV_AnyCooperativeMatrix : DialectType; def SPIRV_AnyImage : DialectType; +def SPIRV_AnyVector : DialectType; def SPIRV_AnyMatrix : DialectType; def SPIRV_AnyRTArray : DialectType; def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>; def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>; +def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>; def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>; def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>; def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>; @@ -4553,7 +4557,7 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv, SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem, SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod, - SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, + SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector, SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td index a6f0f41429bcb..5bd99386e0085 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td @@ -114,6 +114,47 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op< // ----- +def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> { + let summary = "Linear-algebraic multiply of matrix X vector."; + + let description = [{ + Result Type must be a vector of floating-point type. + + Matrix must be an OpTypeMatrix whose Column Type is Result Type. + + Vector must be a vector with the same Component Type as the Component Type in Result Type. Its number of components must equal the number of columns in Matrix. + + #### Example: + + ```mlir + %0 = spirv.MatrixTimesVector %matrix, %vector : + !spirv.matrix<3 x vector<2xf32>>, vector<3xf32> -> vector<2xf32> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_Matrix]> + ]; + + let arguments = (ins + SPIRV_AnyMatrix:$matrix, + SPIRV_AnyVector:$vector + ); + + let results = (outs + SPIRV_AnyVector:$result + ); + + let assemblyFormat = [{ + operands attr-dict `:` type($matrix) `,` type($vector) `->` type($result) + }]; +} + +// ----- + def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> { let summary = "Transpose a matrix."; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 26559c1321db5..040bf6a34cea7 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1698,6 +1698,33 @@ LogicalResult spirv::TransposeOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// spirv.MatrixTimesVector +//===----------------------------------------------------------------------===// + +LogicalResult spirv::MatrixTimesVectorOp::verify() { + auto matrixType = llvm::cast(getMatrix().getType()); + auto vectorType = llvm::cast(getVector().getType()); + auto resultType = llvm::cast(getType()); + + if (matrixType.getNumColumns() != vectorType.getNumElements()) + return emitOpError("matrix columns (") + << matrixType.getNumColumns() << ") must match vector operand size (" + << vectorType.getNumElements() << ")"; + + if (resultType.getNumElements() != matrixType.getNumRows()) + return emitOpError("result size (") + << resultType.getNumElements() << ") must match the matrix rows (" + << matrixType.getNumRows() << ")"; + + auto matrixElementType = matrixType.getElementType(); + if (matrixElementType != vectorType.getElementType() || + matrixElementType != resultType.getElementType()) + return emitOpError("matrix, vector, and result element types must match"); + + return success(); +} + //===----------------------------------------------------------------------===// // spirv.MatrixTimesMatrix //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir index 372fcc6e514b9..37e7514d664ef 100644 --- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir @@ -29,6 +29,13 @@ spirv.module Logical GLSL450 requires #spirv.vce { spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>> } + // CHECK-LABEL: @matrix_times_vector_1 + spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32> + %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32> + spirv.ReturnValue %result : vector<4xf32> + } + // CHECK-LABEL: @matrix_times_matrix_1 spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{ // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>> @@ -124,3 +131,27 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>> return } + +// ----- + +func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) { + // expected-error @+1 {{matrix, vector, and result element types must match}} + %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32> + return +} + +// ----- + +func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) { + // expected-error @+1 {{spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}} + %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32> + return +} + +// ----- + +func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<3xf32>) { + // expected-error @+1 {{spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}} + %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32> + return +} diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir index 2a391df4bff39..0ec1dc27e4e93 100644 --- a/mlir/test/Target/SPIRV/matrix.mlir +++ b/mlir/test/Target/SPIRV/matrix.mlir @@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce { spirv.ReturnValue %result : !spirv.matrix<2 x vector<3xf32>> } + // CHECK-LABEL: @matrix_times_vector_1 + spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32> + %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32> + spirv.ReturnValue %result : vector<4xf32> + } + // CHECK-LABEL: @matrix_times_matrix_1 spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{ // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>