@@ -4234,8 +4234,13 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
42344234 "::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
42354235 "Cooperative Matrix">;
42364236
4237+ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
4238+ ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsMatrixType,
4239+ "::llvm::cast<::mlir::spirv::MatrixType>($_self).getElementType()",
4240+ "Matrix">;
4241+
42374242class SPIRV_VectorOf<Type type> :
4238- VectorOfLengthAndType<[2, 3, 4, 8,16], [type]>;
4243+ VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
42394244
42404245class SPIRV_ScalarOrVectorOf<Type type> :
42414246 AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
@@ -4248,6 +4253,9 @@ class SPIRV_MatrixOrCoopMatrixOf<Type type> :
42484253 AnyTypeOf<[SPIRV_AnyMatrix,
42494254 SPIRV_CoopMatrixOfType<[type]>]>;
42504255
4256+ class SPIRV_MatrixOf<Type type> :
4257+ SPIRV_MatrixOfType<[type]>;
4258+
42514259def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
42524260def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
42534261
@@ -4387,7 +4395,8 @@ def SPIRV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
43874395def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
43884396def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
43894397def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
4390- def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
4398+ def SPIRV_OC_OpVectorTimesMatrix : I32EnumAttrCase<"OpVectorTimesMatrix", 144>;
4399+ def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
43914400def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
43924401def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
43934402def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4559,7 +4568,8 @@ def SPIRV_OpcodeAttr :
45594568 SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
45604569 SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
45614570 SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
4562- SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
4571+ SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
4572+ SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector,
45634573 SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
45644574 SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
45654575 SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
0 commit comments