1- // ===- IntegerDotProductOps .cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
1+ // ===- DotProductOps .cpp - MLIR SPIR-V Dot Product Ops --------------- ----===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===----------------------------------------------------------------------===//
88//
9- // Defines the Integer Dot Product operations in the SPIR-V dialect.
9+ // Defines the Dot Product operations in the SPIR-V dialect.
1010//
1111// ===----------------------------------------------------------------------===//
1212
@@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;
2121
2222namespace mlir ::spirv {
2323
24+ // ===----------------------------------------------------------------------===//
25+ // Dot Product ops
26+ // ===----------------------------------------------------------------------===//
27+
28+ static std::optional<spirv::Version> getDotProductMinVersion () {
29+ return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
30+ }
31+
32+ static std::optional<spirv::Version> getDotProductMaxVersion () {
33+ return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
34+ }
35+
36+ SmallVector<ArrayRef<spirv::Extension>, 1 > DotOp::getExtensions () {
37+ if (isa<BFloat16Type>(getType ())) {
38+ static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
39+ return {extension};
40+ }
41+
42+ return {};
43+ }
44+
45+ SmallVector<ArrayRef<spirv::Capability>, 1 > DotOp::getCapabilities () {
46+ if (isa<BFloat16Type>(getType ())) {
47+ static const auto capability = spirv::Capability::BFloat16DotProductKHR;
48+ return {capability};
49+ }
50+
51+ return {};
52+ }
53+
54+ std::optional<spirv::Version> DotOp::getMinVersion () {
55+ return getDotProductMinVersion ();
56+ }
57+
58+ std::optional<spirv::Version> DotOp::getMaxVersion () {
59+ return getDotProductMaxVersion ();
60+ }
61+
2462// ===----------------------------------------------------------------------===//
2563// Integer Dot Product ops
2664// ===----------------------------------------------------------------------===//
@@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
71109 return success ();
72110}
73111
74- static std::optional<spirv::Version> getIntegerDotProductMinVersion () {
75- return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
76- }
77-
78- static std::optional<spirv::Version> getIntegerDotProductMaxVersion () {
79- return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
80- }
81-
82112static SmallVector<ArrayRef<spirv::Extension>, 1 >
83113getIntegerDotProductExtensions () {
84114 // Requires the SPV_KHR_integer_dot_product extension, specified either
@@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
136166 return getIntegerDotProductCapabilities<OpName>(*this ); \
137167 } \
138168 std::optional<spirv::Version> OpName::getMinVersion () { \
139- return getIntegerDotProductMinVersion (); \
169+ return getDotProductMinVersion (); \
140170 } \
141171 std::optional<spirv::Version> OpName::getMaxVersion () { \
142- return getIntegerDotProductMaxVersion (); \
172+ return getDotProductMaxVersion (); \
143173 }
144174
145175SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP (SDotOp)
0 commit comments