Skip to content

Commit c2d667e

Browse files
committed
[mlir][spirv] Make CooperativeMatrixType a ShapedType
This is to enable `CooperativeMatrixType` to be used with `DenseElementsAttr`, so that a `spirv.Constant` can be easily built from `OpConstantComposite`. For example: ```mlir %cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc> ``` Additional constraints are added to arithmetic operations, as `SameOperandsAndResultType` can no longer fully verify CoopMatrices. This is because for shaped types the verifier only checks element type and shapes, whereas for any other arbitrary type it looks for an exact match. This patch does not enable the actual deserialization.
1 parent 91a7085 commit c2d667e

File tree

4 files changed

+47
-13
lines changed

4 files changed

+47
-13
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,21 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
1818
include "mlir/Interfaces/InferTypeOpInterface.td"
1919
include "mlir/Interfaces/SideEffectInterfaces.td"
2020

21+
class SPIRV_SameCoopMatrix<string lhs, string rhs> : PredOpTrait<
22+
"cooperative matrix types match",
23+
CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) "
24+
"&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))"
25+
"? $" # lhs # ".getType() == $" # rhs # ".getType() : true">
26+
>;
27+
2128
class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
2229
list<Trait> traits = []> :
2330
// Operands type same as result type.
2431
SPIRV_BinaryOp<mnemonic, type, type,
2532
!listconcat(traits,
26-
[Pure, SameOperandsAndResultType])> {
33+
[Pure, SameOperandsAndResultType,
34+
SPIRV_SameCoopMatrix<"operand1", "operand2">,
35+
SPIRV_SameCoopMatrix<"operand2", "result">])> {
2736
// In addition to normal types arithmetic instructions can support cooperative
2837
// matrix.
2938
let arguments = (ins
@@ -42,7 +51,8 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
4251
// Operand type same as result type.
4352
SPIRV_UnaryOp<mnemonic, type, type,
4453
!listconcat(traits,
45-
[Pure, SameOperandsAndResultType])> {
54+
[Pure, SameOperandsAndResultType,
55+
SPIRV_SameCoopMatrix<"operand", "result">])> {
4656
// In addition to normal types arithmetic instructions can support cooperative
4757
// matrix.
4858
let arguments = (ins

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
394394
// SPIR-V KHR cooperative matrix type
395395
class CooperativeMatrixType
396396
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
397-
detail::CooperativeMatrixTypeStorage> {
397+
detail::CooperativeMatrixTypeStorage,
398+
ShapedType::Trait> {
398399
public:
399400
using Base::Base;
400401

@@ -418,6 +419,23 @@ class CooperativeMatrixType
418419
std::optional<StorageClass> storage = std::nullopt);
419420
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
420421
std::optional<StorageClass> storage = std::nullopt);
422+
423+
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
424+
425+
ArrayRef<int64_t> getShape() const;
426+
427+
bool hasRank() const { return true; }
428+
429+
CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
430+
Type elementType) const {
431+
if (shape == std::nullopt)
432+
return get(elementType, getRows(), getColumns(), getScope(), getUse());
433+
else {
434+
assert(shape.value().size() == 2);
435+
return get(elementType, shape.value()[0], shape.value()[1], getScope(),
436+
getUse());
437+
}
438+
}
421439
};
422440

423441
// SPIR-V matrix type

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
195195

196196
struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
197197
using KeyTy =
198-
std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
198+
std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
199199

200200
static CooperativeMatrixTypeStorage *
201201
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
@@ -204,17 +204,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
204204
}
205205

206206
bool operator==(const KeyTy &key) const {
207-
return key == KeyTy(elementType, rows, columns, scope, use);
207+
return key == KeyTy(elementType, shape[0], shape[1], scope, use);
208208
}
209209

210210
CooperativeMatrixTypeStorage(const KeyTy &key)
211-
: elementType(std::get<0>(key)), rows(std::get<1>(key)),
212-
columns(std::get<2>(key)), scope(std::get<3>(key)),
211+
: elementType(std::get<0>(key)),
212+
shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
213213
use(std::get<4>(key)) {}
214214

215215
Type elementType;
216-
uint32_t rows;
217-
uint32_t columns;
216+
// [#rows, #columns]
217+
SmallVector<int64_t, 2> shape;
218218
Scope scope;
219219
CooperativeMatrixUseKHR use;
220220
};
@@ -231,10 +231,16 @@ Type CooperativeMatrixType::getElementType() const {
231231
return getImpl()->elementType;
232232
}
233233

234-
uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
234+
uint32_t CooperativeMatrixType::getRows() const {
235+
return static_cast<uint32_t>(getImpl()->shape[0]);
236+
}
235237

236238
uint32_t CooperativeMatrixType::getColumns() const {
237-
return getImpl()->columns;
239+
return static_cast<uint32_t>(getImpl()->shape[1]);
240+
}
241+
242+
ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
243+
return getImpl()->shape;
238244
}
239245

240246
Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }

mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
524524

525525
spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
526526
%b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
527-
// expected-error @+1 {{op requires the same type for all operands and results}}
527+
// expected-error @+1 {{op failed to verify that cooperative matrix types match}}
528528
%q = "spirv.IAdd"(%a, %b) :
529529
(!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
530530
-> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
535535

536536
spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
537537
%b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
538-
// expected-error @+1 {{op requires the same type for all operands and results}}
538+
// expected-error @+1 {{op failed to verify that cooperative matrix types match}}
539539
%q = "spirv.FAdd"(%a, %b) :
540540
(!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
541541
-> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>

0 commit comments

Comments
 (0)