Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -445,16 +445,19 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];

let arguments = (ins
SPIRV_VectorOf<SPIRV_Float>:$vector1,
SPIRV_VectorOf<SPIRV_Float>:$vector2
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
);

let results = (outs
SPIRV_Float:$result
SPIRV_AnyFloat:$result
);

let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";

// Require dynamic availability specification based on operand/result type.
bit autogenAvailability = 0;

let hasVerifier = 0;
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ add_mlir_dialect_library(MLIRSPIRVDialect
CastOps.cpp
ControlFlowOps.cpp
CooperativeMatrixOps.cpp
DotProductOps.cpp
GroupOps.cpp
ImageOps.cpp
IntegerDotProductOps.cpp
MemoryOps.cpp
MeshOps.cpp
SPIRVAttributes.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the Integer Dot Product operations in the SPIR-V dialect.
// Defines the Dot Product operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

Expand All @@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;

namespace mlir::spirv {

//===----------------------------------------------------------------------===//
// Dot Product ops
//===----------------------------------------------------------------------===//

static std::optional<spirv::Version> getDotProductMinVersion() {
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
}

static std::optional<spirv::Version> getDotProductMaxVersion() {
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
}
Comment on lines +32 to +34
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like hardcoding the maximum version here - is there currently a way to retrieve the "default" maximum spirv version? From my understanding the only place this is defined is inside the availability field inside the tablegen definition for SPIRV_Op, and there is no way to retrieve that here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not aware of anything better. We could probably have a define for spirv::Version::Latest but in tests we'd still have to test for the exact version.


SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
if (getResult().getType().isBF16()) {
static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
return {extension};
}

return {};
}

SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
if (getResult().getType().isBF16()) {
static const auto capability = spirv::Capability::BFloat16DotProductKHR;
return {capability};
}

return {};
}

std::optional<spirv::Version> DotOp::getMinVersion() {
return getDotProductMinVersion();
}

std::optional<spirv::Version> DotOp::getMaxVersion() {
return getDotProductMaxVersion();
}

//===----------------------------------------------------------------------===//
// Integer Dot Product ops
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
return success();
}

static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
}

static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
}

static SmallVector<ArrayRef<spirv::Extension>, 1>
getIntegerDotProductExtensions() {
// Requires the SPV_KHR_integer_dot_product extension, specified either
Expand Down Expand Up @@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
return getIntegerDotProductCapabilities<OpName>(*this); \
} \
std::optional<spirv::Version> OpName::getMinVersion() { \
return getIntegerDotProductMinVersion(); \
return getDotProductMinVersion(); \
} \
std::optional<spirv::Version> OpName::getMaxVersion() { \
return getIntegerDotProductMaxVersion(); \
return getDotProductMaxVersion(); \
}

SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
Expand Down
11 changes: 10 additions & 1 deletion mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {

// -----

// CHECK-LABEL: @dot_bf16
func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
return %0 : bf16
}

// -----

// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
Expand All @@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----

func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/availability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
return %r: i64
}

//===----------------------------------------------------------------------===//
// Dot Product op with bfloat16
//===----------------------------------------------------------------------===//

// CHECK-LABEL: dot_vector_4xbf16_bf16
func.func @dot_vector_4xbf16_bf16(%a: vector<4xbf16>, %b: vector<4xbf16>) -> bf16 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
// CHECK: extensions: [ [SPV_KHR_bfloat16] ]
// CHECK: capabilities: [ [BFloat16DotProductKHR] ]
%r = spirv.Dot %a, %a: vector<4xbf16> -> bf16
return %r: bf16
}

//===----------------------------------------------------------------------===//
// Primitive ops
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Target/SPIRV/arithmetic-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
spirv.Return
}
spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
spirv.Return
}
}
Loading