Skip to content

Commit 760787b

Browse files
committed
Address feedback
1 parent c2d667e commit 760787b

File tree

4 files changed

+23
-21
lines changed

4 files changed

+23
-21
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,12 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
1818
include "mlir/Interfaces/InferTypeOpInterface.td"
1919
include "mlir/Interfaces/SideEffectInterfaces.td"
2020

21-
class SPIRV_SameCoopMatrix<string lhs, string rhs> : PredOpTrait<
22-
"cooperative matrix types match",
23-
CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) "
24-
"&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))"
25-
"? $" # lhs # ".getType() == $" # rhs # ".getType() : true">
26-
>;
27-
2821
class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
2922
list<Trait> traits = []> :
3023
// Operands type same as result type.
3124
SPIRV_BinaryOp<mnemonic, type, type,
3225
!listconcat(traits,
33-
[Pure, SameOperandsAndResultType,
34-
SPIRV_SameCoopMatrix<"operand1", "operand2">,
35-
SPIRV_SameCoopMatrix<"operand2", "result">])> {
26+
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
3627
// In addition to normal types arithmetic instructions can support cooperative
3728
// matrix.
3829
let arguments = (ins
@@ -51,8 +42,7 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
5142
// Operand type same as result type.
5243
SPIRV_UnaryOp<mnemonic, type, type,
5344
!listconcat(traits,
54-
[Pure, SameOperandsAndResultType,
55-
SPIRV_SameCoopMatrix<"operand", "result">])> {
45+
[Pure, AllTypesMatch<["operand", "result"]>])> {
5646
// In addition to normal types arithmetic instructions can support cooperative
5747
// matrix.
5848
let arguments = (ins

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,13 +428,12 @@ class CooperativeMatrixType
428428

429429
CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
430430
Type elementType) const {
431-
if (shape == std::nullopt)
431+
if (!shape)
432432
return get(elementType, getRows(), getColumns(), getScope(), getUse());
433-
else {
434-
assert(shape.value().size() == 2);
435-
return get(elementType, shape.value()[0], shape.value()[1], getScope(),
436-
getUse());
437-
}
433+
434+
assert(shape.value().size() == 2);
435+
return get(elementType, shape.value()[0], shape.value()[1], getScope(),
436+
getUse());
438437
}
439438
};
440439

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,19 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
194194
//===----------------------------------------------------------------------===//
195195

196196
struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
197+
// In the specification dimensions of the Cooperative Matrix are 32-bit
198+
// integers --- the initial implementation kept those values as such. However,
199+
// the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
200+
// as 32-bits and expose it as int64_t through `getShape`, however, this
201+
// method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
202+
// 32-bits integers would require an extra logic and storage. So, we diverge
203+
// from the spec and internally represent the dimensions as 64-bit integers,
204+
// so we can easily return an `ArrayRef` from `getShape` without any extra
205+
// logic. Alternatively, we could store both rows and columns (both 32-bits)
206+
// and shape (64-bits), assigning rows and columns to shape whenever
207+
// `getShape` is called. This would be at the cost of extra logic and storage.
208+
// Note: Because `ArrayRef` is returned we cannot construct an object in
209+
// `getShape` on the fly.
197210
using KeyTy =
198211
std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
199212

@@ -214,7 +227,7 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
214227

215228
Type elementType;
216229
// [#rows, #columns]
217-
SmallVector<int64_t, 2> shape;
230+
std::array<int64_t, 2> shape;
218231
Scope scope;
219232
CooperativeMatrixUseKHR use;
220233
};

mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
524524

525525
spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
526526
%b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
527-
// expected-error @+1 {{op failed to verify that cooperative matrix types match}}
527+
// expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}}
528528
%q = "spirv.IAdd"(%a, %b) :
529529
(!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
530530
-> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
535535

536536
spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
537537
%b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
538-
// expected-error @+1 {{op failed to verify that cooperative matrix types match}}
538+
// expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}}
539539
%q = "spirv.FAdd"(%a, %b) :
540540
(!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
541541
-> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>

0 commit comments

Comments
 (0)