Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];

let arguments = (ins
SPIRV_VectorOf<SPIRV_Float>:$vector1,
SPIRV_VectorOf<SPIRV_Float>:$vector2
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
);

let results = (outs
SPIRV_Float:$result
SPIRV_AnyFloat:$result
);

let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
Expand Down
11 changes: 10 additions & 1 deletion mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {

// -----

// CHECK-LABEL: @dot_bf16
func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
return %0 : bf16
}

// -----

// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
Expand All @@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----

func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Target/SPIRV/arithmetic-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
spirv.Return
}
spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
spirv.Return
}
}