diff --git a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h index 0c61f7eb54e2d..72683d50d7411 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h +++ b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h @@ -24,6 +24,7 @@ namespace spirv { class ArrayType; class RuntimeArrayType; class StructType; +class MatrixType; } // namespace spirv /// According to the Vulkan spec "15.6.4. Offset and Stride Assignment": @@ -67,6 +68,8 @@ class VulkanLayoutUtils { static Type decorateType(VectorType vectorType, Size &size, Size &alignment); static Type decorateType(spirv::ArrayType arrayType, Size &size, Size &alignment); + static Type decorateType(spirv::MatrixType matrixType, Size &size, + Size &alignment); static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment); static spirv::StructType decorateType(spirv::StructType structType, Size &size, Size &alignment); diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp index b19495bc37445..51cfe4a68eb2d 100644 --- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp @@ -91,6 +91,8 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, return decorateType(arrayType, size, alignment); if (auto vectorType = dyn_cast(type)) return decorateType(vectorType, size, alignment); + if (auto matrixType = dyn_cast(type)) + return decorateType(matrixType, size, alignment); if (auto arrayType = dyn_cast(type)) { size = std::numeric_limits().max(); return decorateType(arrayType, alignment); @@ -138,6 +140,25 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType, return spirv::ArrayType::get(memberType, numElements, elementSize); } +Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType, + VulkanLayoutUtils::Size &size, + VulkanLayoutUtils::Size &alignment) { + const unsigned numColumns = matrixType.getNumColumns(); + Type columnType = matrixType.getColumnType(); + unsigned numElements = matrixType.getNumElements(); + Type elementType = matrixType.getElementType(); + Size elementSize = 0; + Size elementAlignment = 1; + + decorateType(elementType, elementSize, elementAlignment); + // According to the Vulkan spec: + // "A matrix type inherits scalar alignment from the equivalent array + // declaration." + size = elementSize * numElements; + alignment = elementAlignment; + return spirv::MatrixType::get(columnType, numColumns); +} + Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType, VulkanLayoutUtils::Size &alignment) { auto elementType = arrayType.getElementType(); diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir index 05ab91b6db6bd..b63a08d96e6af 100644 --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -497,6 +497,11 @@ func.func private @matrix_type(!spirv.matrix<4 x vector<4xf16>>) -> () // ----- +// CHECK: func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>) +func.func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>) -> () + +// ----- + // expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}} func.func private @matrix_invalid_size(!spirv.matrix<5 x vector<3xf32>>) -> ()