Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
SPIRV_ScalarOrVectorOf<type>:$operand1,
SPIRV_ScalarOrVectorOf<type>:$operand2
);

let results = (outs
SPIRV_ScalarOrVectorOf<type>:$result
);
let assemblyFormat = "operands attr-dict `:` type($result)";
}

class SPIRV_ArithmeticBinaryOpWithCoopMatrix<string mnemonic, Type type,
list<Trait> traits = []> :
// Operands type same as result type.
SPIRV_BinaryOp<mnemonic, type, type,
!listconcat(traits,
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
// In addition to normal types these arithmetic instructions can support
// cooperative matrix.
let arguments = (ins
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
Expand Down Expand Up @@ -82,7 +101,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,

// -----

def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FAdd", SPIRV_Float, [Commutative]> {
let summary = "Floating-point addition of Operand 1 and Operand 2.";

let description = [{
Expand All @@ -104,7 +123,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>

// -----

def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FDiv", SPIRV_Float, []> {
let summary = "Floating-point division of Operand 1 divided by Operand 2.";

let description = [{
Expand Down Expand Up @@ -154,7 +173,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {

// -----

def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FMul", SPIRV_Float, [Commutative]> {
let summary = "Floating-point multiplication of Operand 1 and Operand 2.";

let description = [{
Expand Down Expand Up @@ -229,7 +248,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {

// -----

def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FSub", SPIRV_Float, []> {
let summary = "Floating-point subtraction of Operand 2 from Operand 1.";

let description = [{
Expand All @@ -251,9 +270,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {

// -----

def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IAdd",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
let summary = "Integer addition of Operand 1 and Operand 2.";

let description = [{
Expand Down Expand Up @@ -322,9 +341,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",

// -----

def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IMul",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
let summary = "Integer multiplication of Operand 1 and Operand 2.";

let description = [{
Expand Down Expand Up @@ -354,9 +373,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",

// -----

def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"ISub",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
let summary = "Integer subtraction of Operand 2 from Operand 1.";

let description = [{
Expand Down Expand Up @@ -460,9 +479,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",

// -----

def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"SDiv",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
let summary = "Signed-integer division of Operand 1 divided by Operand 2.";

let description = [{
Expand Down Expand Up @@ -622,9 +641,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",

// -----

def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
SPIRV_Integer,
[UnsignedOp, UsableInSpecConstantOp]> {
def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"UDiv",
SPIRV_Integer,
[UnsignedOp, UsableInSpecConstantOp]> {
let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";

let description = [{
Expand Down
43 changes: 43 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -549,3 +549,46 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix
%p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
spirv.Return
}

// -----

// These binary arithmetic instructions do not support coop matrix operands.

spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
%p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
spirv.Return
}

// -----

spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
%p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
spirv.Return
}

// -----
spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
%p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}

// -----

spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
%p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}

// -----

spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
%p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}

// -----
Loading