Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4171,6 +4171,7 @@ def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
def SPIRV_IsCooperativeMatrixType :
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
def SPIRV_IsVectorType : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">;
def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
Expand Down Expand Up @@ -4202,6 +4203,8 @@ def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
"any SPIR-V cooperative matrix type">;
def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
"any SPIR-V image type">;
def SPIRV_AnyVector : DialectType<SPIRV_Dialect, SPIRV_IsVectorType,
"any SPIR-V vector type">;
def SPIRV_AnyMatrix : DialectType<SPIRV_Dialect, SPIRV_IsMatrixType,
"any SPIR-V matrix type">;
def SPIRV_AnyRTArray : DialectType<SPIRV_Dialect, SPIRV_IsRTArrayType,
Expand Down Expand Up @@ -4384,6 +4387,7 @@ def SPIRV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
Expand Down Expand Up @@ -4553,7 +4557,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
Expand Down
41 changes: 41 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,47 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<

// -----

def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
let summary = "Linear-algebraic multiply of matrix X vector.";

let description = [{
Result Type must be a vector of floating-point type.

Matrix must be an OpTypeMatrix whose Column Type is Result Type.

Vector must be a vector with the same Component Type as the Component Type in Result Type. Its number of components must equal the number of columns in Matrix.

#### Example:

```mlir
%0 = spirv.MatrixTimesVector %matrix, %vector :
!spirv.matrix<3 x vector<2xf32>>, vector<3xf32> -> vector<2xf32>
```
}];

let availability = [
MinVersion<SPIRV_V_1_0>,
MaxVersion<SPIRV_V_1_6>,
Extension<[]>,
Capability<[SPIRV_C_Matrix]>
];

let arguments = (ins
SPIRV_AnyMatrix:$matrix,
SPIRV_AnyVector:$vector
);

let results = (outs
SPIRV_AnyVector:$result
);

let assemblyFormat = [{
operands attr-dict `:` type($matrix) `,` type($vector) `->` type($result)
}];
}

// -----

def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
let summary = "Transpose a matrix.";

Expand Down
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,33 @@ LogicalResult spirv::TransposeOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// spirv.MatrixTimesVector
//===----------------------------------------------------------------------===//

LogicalResult spirv::MatrixTimesVectorOp::verify() {
auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
auto vectorType = llvm::cast<VectorType>(getVector().getType());
auto resultType = llvm::cast<VectorType>(getType());

if (matrixType.getNumColumns() != vectorType.getNumElements())
return emitOpError("matrix columns (")
<< matrixType.getNumColumns() << ") must match vector operand size ("
<< vectorType.getNumElements() << ")";

if (resultType.getNumElements() != matrixType.getNumRows())
return emitOpError("result size (")
<< resultType.getNumElements() << ") must match the matrix rows ("
<< matrixType.getNumRows() << ")";

auto matrixElementType = matrixType.getElementType();
if (matrixElementType != vectorType.getElementType() ||
matrixElementType != resultType.getElementType())
return emitOpError("matrix, vector, and result element types must match");

return success();
}

//===----------------------------------------------------------------------===//
// spirv.MatrixTimesMatrix
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 31 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
}

// CHECK-LABEL: @matrix_times_vector_1
spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
// CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
spirv.ReturnValue %result : vector<4xf32>
}

// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
Expand Down Expand Up @@ -124,3 +131,27 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
return
}

// -----

func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) {
// expected-error @+1 {{matrix, vector, and result element types must match}}
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32>
return
}

// -----

func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) {
// expected-error @+1 {{spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}}
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32>
return
}

// -----

func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<3xf32>) {
// expected-error @+1 {{spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}}
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
return
}
14 changes: 14 additions & 0 deletions mlir/test/Target/SPIRV/matrix.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %result : !spirv.matrix<2 x vector<3xf32>>
}

// CHECK-LABEL: @matrix_times_vector_1
spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
// CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
spirv.ReturnValue %result : vector<4xf32>
}

// CHECK-LABEL: @matrix_times_vector_2
spirv.func @matrix_times_vector_2(%arg0: vector<3xf32>, %arg1: !spirv.matrix<3 x vector<4xf32>>) -> vector<4xf32> "None" {
// CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
%result = spirv.MatrixTimesVector %arg1, %arg0 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
spirv.ReturnValue %result : vector<4xf32>
}

// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
Expand Down
Loading