From c2d667e440996bac737bd043fdc7be352c031d0a Mon Sep 17 00:00:00 2001 From: Igor Wodiany Date: Tue, 20 May 2025 17:42:12 +0100 Subject: [PATCH 1/3] [mlir][spirv] Make `CooperativeMatrixType` a `ShapedType` This is to enable `CooperativeMatrixType` to be used with `DenseElementsAttr`, so that a `spirv.Constant` can be easily built from `OpConstantComposite`. For example: ```mlir %cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc> ``` Additional constraints are added to arithmetic operations, as `SameOperandsAndResultType` can no longer fully verify CoopMatrices. This is because for shaped types the verifier only checks element type and shapes, whereas for any other arbitrary type it looks for an exact match. This patch does not enable the actual deserialization. --- .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 14 ++++++++++-- .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 20 ++++++++++++++++- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 22 ++++++++++++------- .../SPIRV/IR/khr-cooperative-matrix-ops.mlir | 4 ++-- 4 files changed, 47 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 22d5afcd77381..48f525e048e60 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -18,12 +18,21 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +class SPIRV_SameCoopMatrix : PredOpTrait< + "cooperative matrix types match", + CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) " + "&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))" + "? $" # lhs # ".getType() == $" # rhs # ".getType() : true"> +>; + class SPIRV_ArithmeticBinaryOp traits = []> : // Operands type same as result type. SPIRV_BinaryOp { + [Pure, SameOperandsAndResultType, + SPIRV_SameCoopMatrix<"operand1", "operand2">, + SPIRV_SameCoopMatrix<"operand2", "result">])> { // In addition to normal types arithmetic instructions can support cooperative // matrix. let arguments = (ins @@ -42,7 +51,8 @@ class SPIRV_ArithmeticUnaryOp { + [Pure, SameOperandsAndResultType, + SPIRV_SameCoopMatrix<"operand", "result">])> { // In addition to normal types arithmetic instructions can support cooperative // matrix. let arguments = (ins diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 2e29e9afaabf4..a7b6569245dd5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); // SPIR-V KHR cooperative matrix type class CooperativeMatrixType : public Type::TypeBase { + detail::CooperativeMatrixTypeStorage, + ShapedType::Trait> { public: using Base::Base; @@ -418,6 +419,23 @@ class CooperativeMatrixType std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); + + operator ShapedType() const { return llvm::cast(*this); } + + ArrayRef getShape() const; + + bool hasRank() const { return true; } + + CooperativeMatrixType cloneWith(std::optional> shape, + Type elementType) const { + if (shape == std::nullopt) + return get(elementType, getRows(), getColumns(), getScope(), getUse()); + else { + assert(shape.value().size() == 2); + return get(elementType, shape.value()[0], shape.value()[1], getScope(), + getUse()); + } + } }; // SPIR-V matrix type diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 337df3a5a65f0..de2034680cd5f 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -195,7 +195,7 @@ std::optional CompositeType::getSizeInBytes() { struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage { using KeyTy = - std::tuple; + std::tuple; static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key) { @@ -204,17 +204,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage { } bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, rows, columns, scope, use); + return key == KeyTy(elementType, shape[0], shape[1], scope, use); } CooperativeMatrixTypeStorage(const KeyTy &key) - : elementType(std::get<0>(key)), rows(std::get<1>(key)), - columns(std::get<2>(key)), scope(std::get<3>(key)), + : elementType(std::get<0>(key)), + shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)), use(std::get<4>(key)) {} Type elementType; - uint32_t rows; - uint32_t columns; + // [#rows, #columns] + SmallVector shape; Scope scope; CooperativeMatrixUseKHR use; }; @@ -231,10 +231,16 @@ Type CooperativeMatrixType::getElementType() const { return getImpl()->elementType; } -uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; } +uint32_t CooperativeMatrixType::getRows() const { + return static_cast(getImpl()->shape[0]); +} uint32_t CooperativeMatrixType::getColumns() const { - return getImpl()->columns; + return static_cast(getImpl()->shape[1]); +} + +ArrayRef CooperativeMatrixType::getShape() const { + return getImpl()->shape; } Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; } diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index d3e1dbc229ef9..4ae8b70bf43ca 100644 --- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" { spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" { - // expected-error @+1 {{op requires the same type for all operands and results}} + // expected-error @+1 {{op failed to verify that cooperative matrix types match}} %q = "spirv.IAdd"(%a, %b) : (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> @@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" { - // expected-error @+1 {{op requires the same type for all operands and results}} + // expected-error @+1 {{op failed to verify that cooperative matrix types match}} %q = "spirv.FAdd"(%a, %b) : (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc> From 760787b2496a234363b0d162074c100d1398a2f3 Mon Sep 17 00:00:00 2001 From: Igor Wodiany Date: Mon, 9 Jun 2025 14:23:29 +0100 Subject: [PATCH 2/3] Address feedback --- .../mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 14 ++------------ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 11 +++++------ mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 15 ++++++++++++++- .../SPIRV/IR/khr-cooperative-matrix-ops.mlir | 4 ++-- 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 48f525e048e60..309079e549846 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -18,21 +18,12 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -class SPIRV_SameCoopMatrix : PredOpTrait< - "cooperative matrix types match", - CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) " - "&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))" - "? $" # lhs # ".getType() == $" # rhs # ".getType() : true"> ->; - class SPIRV_ArithmeticBinaryOp traits = []> : // Operands type same as result type. SPIRV_BinaryOp, - SPIRV_SameCoopMatrix<"operand2", "result">])> { + [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> { // In addition to normal types arithmetic instructions can support cooperative // matrix. let arguments = (ins @@ -51,8 +42,7 @@ class SPIRV_ArithmeticUnaryOp])> { + [Pure, AllTypesMatch<["operand", "result"]>])> { // In addition to normal types arithmetic instructions can support cooperative // matrix. let arguments = (ins diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index a7b6569245dd5..787535d0a6bd2 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -428,13 +428,12 @@ class CooperativeMatrixType CooperativeMatrixType cloneWith(std::optional> shape, Type elementType) const { - if (shape == std::nullopt) + if (!shape) return get(elementType, getRows(), getColumns(), getScope(), getUse()); - else { - assert(shape.value().size() == 2); - return get(elementType, shape.value()[0], shape.value()[1], getScope(), - getUse()); - } + + assert(shape.value().size() == 2); + return get(elementType, shape.value()[0], shape.value()[1], getScope(), + getUse()); } }; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index de2034680cd5f..2ed78db52c87a 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -194,6 +194,19 @@ std::optional CompositeType::getSizeInBytes() { //===----------------------------------------------------------------------===// struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage { + // In the specification dimensions of the Cooperative Matrix are 32-bit + // integers --- the initial implementation kept those values as such. However, + // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape + // as 32-bits and expose it as int64_t through `getShape`, however, this + // method returns an `ArrayRef`, so returning `ArrayRef` having two + // 32-bits integers would require an extra logic and storage. So, we diverge + // from the spec and internally represent the dimensions as 64-bit integers, + // so we can easily return an `ArrayRef` from `getShape` without any extra + // logic. Alternatively, we could store both rows and columns (both 32-bits) + // and shape (64-bits), assigning rows and columns to shape whenever + // `getShape` is called. This would be at the cost of extra logic and storage. + // Note: Because `ArrayRef` is returned we cannot construct an object in + // `getShape` on the fly. using KeyTy = std::tuple; @@ -214,7 +227,7 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage { Type elementType; // [#rows, #columns] - SmallVector shape; + std::array shape; Scope scope; CooperativeMatrixUseKHR use; }; diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index 4ae8b70bf43ca..8733ff93768ab 100644 --- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" { spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" { - // expected-error @+1 {{op failed to verify that cooperative matrix types match}} + // expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}} %q = "spirv.IAdd"(%a, %b) : (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> @@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" { - // expected-error @+1 {{op failed to verify that cooperative matrix types match}} + // expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}} %q = "spirv.FAdd"(%a, %b) : (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc> From ef0629545f9556c75d9b839a609a78ca54cd4e38 Mon Sep 17 00:00:00 2001 From: Igor Wodiany Date: Mon, 9 Jun 2025 15:27:00 +0100 Subject: [PATCH 3/3] Add asserts --- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 2ed78db52c87a..1aff43c301334 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -245,10 +245,12 @@ Type CooperativeMatrixType::getElementType() const { } uint32_t CooperativeMatrixType::getRows() const { + assert(getImpl()->shape[0] != ShapedType::kDynamic); return static_cast(getImpl()->shape[0]); } uint32_t CooperativeMatrixType::getColumns() const { + assert(getImpl()->shape[1] != ShapedType::kDynamic); return static_cast(getImpl()->shape[1]); }