Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
// Operands type same as result type.
SPIRV_BinaryOp<mnemonic, type, type,
!listconcat(traits,
[Pure, SameOperandsAndResultType])> {
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
Expand All @@ -42,7 +42,7 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
// Operand type same as result type.
SPIRV_UnaryOp<mnemonic, type, type,
!listconcat(traits,
[Pure, SameOperandsAndResultType])> {
[Pure, AllTypesMatch<["operand", "result"]>])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
Expand Down
19 changes: 18 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
detail::CooperativeMatrixTypeStorage> {
detail::CooperativeMatrixTypeStorage,
ShapedType::Trait> {
public:
using Base::Base;

Expand All @@ -418,6 +419,22 @@ class CooperativeMatrixType
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);

operator ShapedType() const { return llvm::cast<ShapedType>(*this); }

ArrayRef<int64_t> getShape() const;

bool hasRank() const { return true; }

CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
if (!shape)
return get(elementType, getRows(), getColumns(), getScope(), getUse());

assert(shape.value().size() == 2);
return get(elementType, shape.value()[0], shape.value()[1], getScope(),
getUse());
}
};

// SPIR-V matrix type
Expand Down
35 changes: 27 additions & 8 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,21 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
//===----------------------------------------------------------------------===//

struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
// In the specification dimensions of the Cooperative Matrix are 32-bit
// integers --- the initial implementation kept those values as such. However,
// the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
// as 32-bits and expose it as int64_t through `getShape`, however, this
// method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
// 32-bits integers would require an extra logic and storage. So, we diverge
// from the spec and internally represent the dimensions as 64-bit integers,
// so we can easily return an `ArrayRef` from `getShape` without any extra
// logic. Alternatively, we could store both rows and columns (both 32-bits)
// and shape (64-bits), assigning rows and columns to shape whenever
// `getShape` is called. This would be at the cost of extra logic and storage.
// Note: Because `ArrayRef` is returned we cannot construct an object in
// `getShape` on the fly.
using KeyTy =
std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;

static CooperativeMatrixTypeStorage *
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
Expand All @@ -204,17 +217,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
}

bool operator==(const KeyTy &key) const {
return key == KeyTy(elementType, rows, columns, scope, use);
return key == KeyTy(elementType, shape[0], shape[1], scope, use);
}

CooperativeMatrixTypeStorage(const KeyTy &key)
: elementType(std::get<0>(key)), rows(std::get<1>(key)),
columns(std::get<2>(key)), scope(std::get<3>(key)),
: elementType(std::get<0>(key)),
shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
use(std::get<4>(key)) {}

Type elementType;
uint32_t rows;
uint32_t columns;
// [#rows, #columns]
std::array<int64_t, 2> shape;
Scope scope;
CooperativeMatrixUseKHR use;
};
Expand All @@ -231,10 +244,16 @@ Type CooperativeMatrixType::getElementType() const {
return getImpl()->elementType;
}

uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
uint32_t CooperativeMatrixType::getRows() const {
return static_cast<uint32_t>(getImpl()->shape[0]);
}

uint32_t CooperativeMatrixType::getColumns() const {
return getImpl()->columns;
return static_cast<uint32_t>(getImpl()->shape[1]);
}

ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
return getImpl()->shape;
}

Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {

spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
%b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
// expected-error @+1 {{op requires the same type for all operands and results}}
// expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}}
%q = "spirv.IAdd"(%a, %b) :
(!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
-> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
Expand All @@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,

spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
%b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
// expected-error @+1 {{op requires the same type for all operands and results}}
// expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}}
%q = "spirv.FAdd"(%a, %b) :
(!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
-> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
Expand Down
Loading