Skip to content

Commit e51870f

Browse files
committed
Add availability generation and test
1 parent 71755ac commit e51870f

File tree

4 files changed

+60
-13
lines changed

4 files changed

+60
-13
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
455455

456456
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
457457

458+
// Require dynamic availability specification based on operand/result type.
459+
bit autogenAvailability = 0;
460+
458461
let hasVerifier = 0;
459462
}
460463

mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ add_mlir_dialect_library(MLIRSPIRVDialect
77
CastOps.cpp
88
ControlFlowOps.cpp
99
CooperativeMatrixOps.cpp
10+
DotProductOps.cpp
1011
GroupOps.cpp
1112
ImageOps.cpp
12-
IntegerDotProductOps.cpp
1313
MemoryOps.cpp
1414
MeshOps.cpp
1515
SPIRVAttributes.cpp

mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp renamed to mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
1+
//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops ----===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// Defines the Integer Dot Product operations in the SPIR-V dialect.
9+
// Defines the Dot Product operations in the SPIR-V dialect.
1010
//
1111
//===----------------------------------------------------------------------===//
1212

@@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;
2121

2222
namespace mlir::spirv {
2323

24+
//===----------------------------------------------------------------------===//
25+
// Dot Product ops
26+
//===----------------------------------------------------------------------===//
27+
28+
static std::optional<spirv::Version> getDotProductMinVersion() {
29+
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
30+
}
31+
32+
static std::optional<spirv::Version> getDotProductMaxVersion() {
33+
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
34+
}
35+
36+
SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
37+
if (getResult().getType().isBF16()) {
38+
static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
39+
return {extension};
40+
}
41+
42+
return {};
43+
}
44+
45+
SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
46+
if (getResult().getType().isBF16()) {
47+
static const auto capability = spirv::Capability::BFloat16DotProductKHR;
48+
return {capability};
49+
}
50+
51+
return {};
52+
}
53+
54+
std::optional<spirv::Version> DotOp::getMinVersion() {
55+
return getDotProductMinVersion();
56+
}
57+
58+
std::optional<spirv::Version> DotOp::getMaxVersion() {
59+
return getDotProductMaxVersion();
60+
}
61+
2462
//===----------------------------------------------------------------------===//
2563
// Integer Dot Product ops
2664
//===----------------------------------------------------------------------===//
@@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
71109
return success();
72110
}
73111

74-
static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
75-
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
76-
}
77-
78-
static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
79-
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
80-
}
81-
82112
static SmallVector<ArrayRef<spirv::Extension>, 1>
83113
getIntegerDotProductExtensions() {
84114
// Requires the SPV_KHR_integer_dot_product extension, specified either
@@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
136166
return getIntegerDotProductCapabilities<OpName>(*this); \
137167
} \
138168
std::optional<spirv::Version> OpName::getMinVersion() { \
139-
return getIntegerDotProductMinVersion(); \
169+
return getDotProductMinVersion(); \
140170
} \
141171
std::optional<spirv::Version> OpName::getMaxVersion() { \
142-
return getIntegerDotProductMaxVersion(); \
172+
return getDotProductMaxVersion(); \
143173
}
144174

145175
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)

mlir/test/Dialect/SPIRV/IR/availability.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,20 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
234234
return %r: i64
235235
}
236236

237+
//===----------------------------------------------------------------------===//
238+
// Dot Product op with bfloat16
239+
//===----------------------------------------------------------------------===//
240+
241+
// CHECK-LABEL: dot_vector_4xbf16_bf16
242+
func.func @dot_vector_4xbf16_bf16(%a: vector<4xbf16>, %b: vector<4xbf16>) -> bf16 {
243+
// CHECK: min version: v1.0
244+
// CHECK: max version: v1.6
245+
// CHECK: extensions: [ [SPV_KHR_bfloat16] ]
246+
// CHECK: capabilities: [ [BFloat16DotProductKHR] ]
247+
%r = spirv.Dot %a, %a: vector<4xbf16> -> bf16
248+
return %r: bf16
249+
}
250+
237251
//===----------------------------------------------------------------------===//
238252
// Primitive ops
239253
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)