@@ -91,6 +91,8 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
9191 return decorateType (arrayType, size, alignment);
9292 if (auto vectorType = dyn_cast<VectorType>(type))
9393 return decorateType (vectorType, size, alignment);
94+ if (auto matrixType = dyn_cast<spirv::MatrixType>(type))
95+ return decorateType (matrixType, size, alignment);
9496 if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
9597 size = std::numeric_limits<Size>().max ();
9698 return decorateType (arrayType, alignment);
@@ -138,6 +140,25 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
138140 return spirv::ArrayType::get (memberType, numElements, elementSize);
139141}
140142
143+ Type VulkanLayoutUtils::decorateType (spirv::MatrixType matrixType,
144+ VulkanLayoutUtils::Size &size,
145+ VulkanLayoutUtils::Size &alignment) {
146+ const auto numColumns = matrixType.getNumColumns ();
147+ const auto columnType = matrixType.getColumnType ();
148+ const auto numElements = matrixType.getNumElements ();
149+ auto elementType = matrixType.getElementType ();
150+ Size elementSize = 0 ;
151+ Size elementAlignment = 1 ;
152+
153+ decorateType (elementType, elementSize, elementAlignment);
154+ // According to the Vulkan spec:
155+ // "A matrix type inherits scalar alignment from the equivalent array
156+ // declaration.
157+ size = elementSize * numElements;
158+ alignment = elementAlignment;
159+ return spirv::MatrixType::get (columnType, numColumns);
160+ }
161+
141162Type VulkanLayoutUtils::decorateType (spirv::RuntimeArrayType arrayType,
142163 VulkanLayoutUtils::Size &alignment) {
143164 auto elementType = arrayType.getElementType ();
0 commit comments