From 136f901344c205e4fe93f165ff4b7e32490abc98 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Fri, 13 Jun 2025 17:03:09 +0200 Subject: [PATCH 1/6] Add support for SPV_ARM_tensors Signed-off-by: Davide Grohmann Signed-off-by: Mohammadreza Ameri Mahabadian Change-Id: If78909a47417ef3dda710847cfe90c34b984ff09 --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 34 ++++- .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 35 ++++- mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 77 ++++++++++- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 6 + mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 124 +++++++++++++++++- .../SPIRV/Deserialization/DeserializeOps.cpp | 1 + .../SPIRV/Deserialization/Deserializer.cpp | 52 ++++++++ .../SPIRV/Deserialization/Deserializer.h | 2 + .../Target/SPIRV/Serialization/Serializer.cpp | 48 +++++++ mlir/test/Dialect/SPIRV/IR/types.mlir | 51 +++++++ mlir/test/Target/SPIRV/tensorARM.mlir | 66 ++++++++++ 11 files changed, 487 insertions(+), 9 deletions(-) create mode 100644 mlir/test/Target/SPIRV/tensorARM.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index d2ba76cdad904..d874817e6888d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -422,6 +422,8 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>; +def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>; + def SPIRV_ExtensionAttr : SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [ SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group, @@ -445,6 +447,7 @@ def SPIRV_ExtensionAttr : SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add, SPV_EXT_mesh_shader, + SPV_ARM_tensors, SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot, SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask, SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod, @@ -1311,6 +1314,24 @@ def SPIRV_C_GeometryStreams : I32EnumAttrCase<"Geome def SPIRV_C_MultiViewport : I32EnumAttrCase<"MultiViewport", 57> { list implies = [SPIRV_C_Geometry]; } +def SPIRV_C_TensorsARM : I32EnumAttrCase<"TensorsARM", 4174> { + list implies = [SPIRV_C_Int8]; + list availability = [ + Extension<[SPV_ARM_tensors]> + ]; +} +def SPIRV_C_StorageTensorArrayDynamicIndexingEXT : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> { + list implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader]; + list availability = [ + Extension<[SPV_ARM_tensors]> + ]; +} +def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> { + list implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform]; + list availability = [ + Extension<[SPV_ARM_tensors]> + ]; +} def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> { list implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR]; list availability = [ @@ -1523,6 +1544,8 @@ def SPIRV_CapabilityAttr : SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize, SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect, SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport, + SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT, + SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers, SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV, SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV, @@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">; def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">; def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">; def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">; - +def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">; // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types // for the definition of the following types and type categories. @@ -4217,6 +4240,8 @@ def SPIRV_AnyStruct : DialectType; def SPIRV_AnySampledImage : DialectType; +def SPIRV_AnyTensorArm : DialectType; def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>; def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>; @@ -4228,7 +4253,7 @@ def SPIRV_Type : AnyTypeOf<[ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector, SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage, - SPIRV_AnyImage + SPIRV_AnyImage, SPIRV_AnyTensorArm ]>; def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>; @@ -4525,6 +4550,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>; def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>; def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>; +def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>; def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>; def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>; @@ -4638,7 +4664,9 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd, SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor, SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr, - SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR, + SPIRV_OC_OpGroupNonUniformLogicalXor, + SPIRV_OC_OpTypeTensorARM, + SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 787535d0a6bd2..7ffea6e7dba81 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -29,6 +29,7 @@ namespace spirv { namespace detail { struct ArrayTypeStorage; struct CooperativeMatrixTypeStorage; +struct TensorArmTypeStorage; struct ImageTypeStorage; struct MatrixTypeStorage; struct PointerTypeStorage; @@ -96,7 +97,8 @@ class ScalarType : public SPIRVType { std::optional getSizeInBytes(); }; -// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType. +// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V +// StructType. class CompositeType : public SPIRVType { public: using SPIRVType::SPIRVType; @@ -477,6 +479,37 @@ class MatrixType : public Type::TypeBase storage = std::nullopt); }; +// SPIR-V TensorARM Type +class TensorArmType + : public Type::TypeBase { +public: + using Base::Base; + + static constexpr StringLiteral name = "spirv.arm.tensor"; + + // TensorArm supports minimum rank of 1, hence an empty shape here means + // unranked. + static TensorArmType get(ArrayRef shape, Type elementType); + TensorArmType cloneWith(std::optional> shape, + Type elementType) const; + + static LogicalResult + verifyInvariants(function_ref emitError, + ArrayRef shape, Type elementType); + + Type getElementType() const; + ArrayRef getShape() const; + unsigned getNumElements() const; + bool hasRank() const { return !getShape().empty(); } + operator ShapedType() const { return llvm::cast(*this); } + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + std::optional storage = std::nullopt); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + std::optional storage = std::nullopt); +}; + } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index a21acef1c4b43..15002f1d5d16e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, << t.getNumElements(); return Type(); } + } else if (auto t = llvm::dyn_cast(type)) { + if (!llvm::isa(t.getElementType())) { + parser.emitError( + typeLoc, "only scalar element type allowed in tensor type but found ") + << t.getElementType(); + return Type(); + } } else { parser.emitError(typeLoc, "cannot use ") << type << " to compose SPIR-V types"; @@ -363,6 +370,54 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use); } +// tensor-arm-type ::= +// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>` +static Type parseTensorArmType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) + return {}; + + bool unranked = false; + SmallVector dims; + SMLoc countLoc = parser.getCurrentLocation(); + + if (parser.parseOptionalStar().succeeded()) { + unranked = true; + if (parser.parseXInDimensionList()) + return {}; + } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) + return {}; + + if (!unranked && dims.empty()) { + parser.emitError(countLoc, "arm.tensors do not support rank zero"); + return {}; + } + + if (std::any_of(dims.begin(), dims.end(), + [](int64_t dim) { return dim == 0; })) { + parser.emitError(countLoc, "arm.tensors do not support zero dimensions"); + return {}; + } + + if (std::any_of(dims.begin(), dims.end(), + [](int64_t dim) { return dim < 0; }) && + std::any_of(dims.begin(), dims.end(), + [](int64_t dim) { return dim > 0; })) { + parser.emitError(countLoc, "arm.tensor shape dimensions must be either " + "fully dynamic or completed shaped"); + return {}; + } + + auto elementTy = parseAndVerifyType(dialect, parser); + if (!elementTy) + return {}; + + if (parser.parseGreater()) + return {}; + + return TensorArmType::get(dims, elementTy); +} + // TODO: Reorder methods to be utilities first and parse*Type // methods in alphabetical order // @@ -759,6 +814,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const { return parseStructType(*this, parser); if (keyword == "matrix") return parseMatrixType(*this, parser); + if (keyword == "arm.tensor") + return parseTensorArmType(*this, parser); parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword; return Type(); } @@ -855,10 +912,28 @@ static void print(MatrixType type, DialectAsmPrinter &os) { os << ">"; } +static void print(TensorArmType type, DialectAsmPrinter &os) { + os << "arm.tensor<"; + + llvm::interleave( + type.getShape(), os, + [&](int64_t dim) { + if (ShapedType::isDynamic(dim)) + os << '?'; + else + os << dim; + }, + "x"); + if (!type.hasRank()) { + os << "*"; + } + os << "x" << type.getElementType() << ">"; +} + void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) .Case( + ImageType, SampledImageType, StructType, MatrixType, TensorArmType>( [&](auto type) { print(type, os); }) .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 7148027dae78d..eb2974d62fdd1 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -547,6 +547,12 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, return failure(); } + if (llvm::isa(type)) { + if (parser.parseOptionalColon().succeeded()) + if (parser.parseType(type)) + return failure(); + } + return parser.addTypeToList(type, result.types); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 93e0c9b33c546..e4eeb0a7f37d5 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -18,8 +18,10 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include #include #include +#include using namespace mlir; using namespace mlir::spirv; @@ -96,7 +98,7 @@ bool CompositeType::classof(Type type) { return isValid(vectorType); return llvm::isa(type); + spirv::StructType, spirv::TensorArmType>(type); } bool CompositeType::isValid(VectorType type) { @@ -107,8 +109,8 @@ bool CompositeType::isValid(VectorType type) { Type CompositeType::getElementType(unsigned index) const { return TypeSwitch(*this) - .Case( - [](auto type) { return type.getElementType(); }) + .Case([](auto type) { return type.getElementType(); }) .Case([](MatrixType type) { return type.getColumnType(); }) .Case( [index](StructType type) { return type.getElementType(index); }) @@ -125,6 +127,8 @@ unsigned CompositeType::getNumElements() const { return structType.getNumElements(); if (auto vectorType = llvm::dyn_cast(*this)) return vectorType.getNumElements(); + if (auto tensorArmType = llvm::dyn_cast(*this)) + return tensorArmType.getNumElements(); if (llvm::isa(*this)) { llvm_unreachable( "invalid to query number of elements of spirv Cooperative Matrix type"); @@ -151,6 +155,14 @@ void CompositeType::getExtensions( return llvm::cast(type.getElementType()) .getExtensions(extensions, storage); }) + .Case([&](TensorArmType type) { + static const Extension exts[] = {Extension::SPV_ARM_tensors}; + ArrayRef ref(exts, std::size(exts)); + extensions.push_back(ref); + return llvm::cast(type.getElementType()) + .getExtensions(extensions, storage); + }) + .Default([](Type) { llvm_unreachable("invalid composite type"); }); } @@ -171,6 +183,13 @@ void CompositeType::getCapabilities( return llvm::cast(type.getElementType()) .getCapabilities(capabilities, storage); }) + .Case([&](TensorArmType type) { + static const Capability caps[] = {Capability::TensorsARM}; + ArrayRef ref(caps, std::size(caps)); + capabilities.push_back(ref); + return llvm::cast(type.getElementType()) + .getCapabilities(capabilities, storage); + }) .Default([](Type) { llvm_unreachable("invalid composite type"); }); } @@ -186,6 +205,13 @@ std::optional CompositeType::getSizeInBytes() { return std::nullopt; return *elementSize * vectorType.getNumElements(); } + if (auto tensorArmType = llvm::dyn_cast(*this)) { + std::optional elementSize = + llvm::cast(tensorArmType.getElementType()).getSizeInBytes(); + if (!elementSize) + return std::nullopt; + return *elementSize * tensorArmType.getNumElements(); + } return std::nullopt; } @@ -691,6 +717,9 @@ bool SPIRVType::classof(Type type) { return true; if (auto vectorType = llvm::dyn_cast(type)) return CompositeType::isValid(vectorType); + if (auto tensorArmType = llvm::dyn_cast(type)) { + return llvm::isa(tensorArmType.getElementType()); + } return false; } @@ -712,6 +741,8 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, matrixType.getExtensions(extensions, storage); } else if (auto ptrType = llvm::dyn_cast(*this)) { ptrType.getExtensions(extensions, storage); + } else if (auto tensorArmType = llvm::dyn_cast(*this)) { + tensorArmType.getExtensions(extensions, storage); } else { llvm_unreachable("invalid SPIR-V Type to getExtensions"); } @@ -732,6 +763,8 @@ void SPIRVType::getCapabilities( matrixType.getCapabilities(capabilities, storage); } else if (auto ptrType = llvm::dyn_cast(*this)) { ptrType.getCapabilities(capabilities, storage); + } else if (auto tensorArmType = llvm::dyn_cast(*this)) { + tensorArmType.getCapabilities(capabilities, storage); } else { llvm_unreachable("invalid SPIR-V Type to getCapabilities"); } @@ -1203,11 +1236,94 @@ void MatrixType::getCapabilities( llvm::cast(getColumnType()).getCapabilities(capabilities, storage); } +//===----------------------------------------------------------------------===// +// TensorArmType +//===----------------------------------------------------------------------===// + +struct spirv::detail::TensorArmTypeStorage final : TypeStorage { + using KeyTy = std::tuple, Type>; + + static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + auto shape = std::get<0>(key); + auto elementType = std::get<1>(key); + shape = allocator.copyInto(shape); + return new (allocator.allocate()) + TensorArmTypeStorage(std::move(shape), std::move(elementType)); + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_combine(std::get<0>(key), std::get<1>(key)); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy(shape, elementType); + } + + TensorArmTypeStorage(ArrayRef shape, Type elementType) + : shape(std::move(shape)), elementType(std::move(elementType)) {} + + ArrayRef shape; + Type elementType; +}; + +TensorArmType TensorArmType::get(ArrayRef shape, Type elementType) { + return Base::get(elementType.getContext(), shape, elementType); +} + +TensorArmType TensorArmType::cloneWith(std::optional> shape, + Type elementType) const { + return TensorArmType::get(shape.value_or(getShape()), elementType); +} + +Type TensorArmType::getElementType() const { return getImpl()->elementType; } +ArrayRef TensorArmType::getShape() const { return getImpl()->shape; } + +unsigned TensorArmType::getNumElements() const { + auto shape = getShape(); + return std::accumulate(shape.begin(), shape.end(), unsigned(1), + std::multiplies()); +} + +void TensorArmType::getExtensions( + SPIRVType::ExtensionArrayRefVector &extensions, + std::optional storage) { + + llvm::cast(getElementType()).getExtensions(extensions, storage); + static constexpr Extension exts[] = {Extension::SPV_ARM_tensors}; + extensions.push_back(exts); +} + +void TensorArmType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + std::optional storage) { + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); + static constexpr Capability caps[] = {Capability::TensorsARM}; + capabilities.push_back(caps); +} + +LogicalResult +TensorArmType::verifyInvariants(function_ref emitError, + ArrayRef shape, Type elementType) { + if (std::any_of(shape.begin(), shape.end(), + [](int64_t dim) { return dim == 0; })) + return emitError() << "arm.tensor do not support dimensions = 0"; + if (std::any_of(shape.begin(), shape.end(), + [](int64_t dim) { return dim < 0; }) && + std::any_of(shape.begin(), shape.end(), + [](int64_t dim) { return dim > 0; })) + return emitError() + << "arm.tensor shape dimensions must be either fully dynamic or " + "completed shaped"; + return success(); +} + //===----------------------------------------------------------------------===// // SPIR-V Dialect //===----------------------------------------------------------------------===// void SPIRVDialect::registerTypes() { addTypes(); + RuntimeArrayType, SampledImageType, StructType, TensorArmType>(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index b30da773d4896..55d6a380d0bff 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -164,6 +164,7 @@ LogicalResult spirv::Deserializer::processInstruction( case spirv::Opcode::OpTypeRuntimeArray: case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: + case spirv::Opcode::OpTypeTensorARM: case spirv::Opcode::OpTypeCooperativeMatrixKHR: return processType(opcode, operands); case spirv::Opcode::OpTypeForwardPointer: diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index b9d9a9015eb61..f0e42047c559e 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -138,6 +138,7 @@ LogicalResult spirv::Deserializer::processHeader() { MIN_VERSION_CASE(3); MIN_VERSION_CASE(4); MIN_VERSION_CASE(5); + MIN_VERSION_CASE(6); #undef MIN_VERSION_CASE default: return emitError(unknownLoc, "unsupported SPIR-V minor version: ") @@ -935,6 +936,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, return processStructType(operands); case spirv::Opcode::OpTypeMatrix: return processMatrixType(operands); + case spirv::Opcode::OpTypeTensorARM: + return processTensorARMType(operands); default: return emitError(unknownLoc, "unhandled type instruction"); } @@ -1238,6 +1241,55 @@ spirv::Deserializer::processMatrixType(ArrayRef operands) { return success(); } +LogicalResult +spirv::Deserializer::processTensorARMType(ArrayRef operands) { + unsigned size = operands.size(); + if (size < 2 || size > 4) { + return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands " + "(result_id, element_type, (rank), (shape))") + << size; + } + Type elementTy = getType(operands[1]); + if (!elementTy) { + return emitError(unknownLoc, + "OpTypeTensorARM references undefined element type.") + << operands[1]; + } + if (size == 2) { + typeMap[operands[0]] = TensorArmType::get({}, elementTy); + return success(); + } + + auto rankAttr = getConstantInt(operands[2]); + if (!rankAttr) + return emitError(unknownLoc, "OpTypeTensorARM rank must come from a " + "scalar integer constant instruction"); + unsigned rank = rankAttr.getValue().getZExtValue(); + if (size == 3) { + SmallVector shape(rank, ShapedType::kDynamic); + typeMap[operands[0]] = TensorArmType::get(shape, elementTy); + return success(); + } + + auto shapeInfo = getConstant(operands[3]); + if (!shapeInfo) { + return emitError(unknownLoc, "OpTypeTensorARM shape must come from a " + "constant instruction of type OpTypeArray"); + } + ArrayAttr shapeArrayAttr = llvm::dyn_cast(shapeInfo->first); + SmallVector shape; + for (auto dimAttr : shapeArrayAttr.getValue()) { + auto dimIntAttr = llvm::dyn_cast(dimAttr); + if (!dimIntAttr) { + return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid " + "dimension size"); + } + shape.push_back(dimIntAttr.getValue().getSExtValue()); + } + typeMap[operands[0]] = TensorArmType::get(shape, elementTy); + return success(); +} + LogicalResult spirv::Deserializer::processTypeForwardPointer(ArrayRef operands) { if (operands.size() != 2) diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index e4556e7652b17..1bc9e4a3c75d8 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -291,6 +291,8 @@ class Deserializer { LogicalResult processMatrixType(ArrayRef operands); + LogicalResult processTensorARMType(ArrayRef operands); + LogicalResult processTypeForwardPointer(ArrayRef operands); //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index d258bfd852961..ebebd2d283afa 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -729,6 +729,54 @@ LogicalResult Serializer::prepareBasicType( return success(); } + if (auto tensorArmType = llvm::dyn_cast(type)) { + uint32_t elementTypeID = 0; + uint32_t rank = 0; + uint32_t shapeID = 0; + uint32_t rankID = 0; + if (failed(processTypeImpl(loc, tensorArmType.getElementType(), + elementTypeID, serializationCtx))) { + return failure(); + } + if (tensorArmType.hasRank()) { + ArrayRef dims = tensorArmType.getShape(); + rank = dims.size(); + rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank)); + if (rankID == 0) { + return failure(); + } + + bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; }); + if (rank > 0 && shaped) { + auto I32Type = IntegerType::get(type.getContext(), 32); + auto shapeType = ArrayType::get(I32Type, rank); + if (rank == 1) { + SmallVector index(rank); + shapeID = prepareDenseElementsConstant( + loc, shapeType, + mlirBuilder.getI32TensorAttr(SmallVector(dims)), 0, + index); + } else { + shapeID = prepareArrayConstant( + loc, shapeType, + mlirBuilder.getI32ArrayAttr(SmallVector(dims))); + } + if (shapeID == 0) { + return failure(); + } + } + } + typeEnum = spirv::Opcode::OpTypeTensorARM; + operands.push_back(elementTypeID); + if (rankID == 0) + return success(); + operands.push_back(rankID); + if (shapeID == 0) + return success(); + operands.push_back(shapeID); + return success(); + } + // TODO: Handle other types. return emitError(loc, "unhandled type in serialization: ") << type; } diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir index c23894c62826b..7d45b5ea82643 100644 --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -564,3 +564,54 @@ func.func private @matrix_size_type(!spirv.matrix< x vector<3xi32>>) -> () func.func private @matrix_size_type(!spirv.matrix<2.0 x vector<3xi32>>) -> () // ----- + +//===----------------------------------------------------------------------===// +// TensorArm +//===----------------------------------------------------------------------===// + +// CHECK: func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>) +func.func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>) -> () + +// ----- + +// CHECK: func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>) +func.func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>) -> () + +// ----- + +// CHECK: func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>) +func.func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>) -> () + +// ----- + +// CHECK: func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>) +func.func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>) -> () + +// ----- + +// CHECK: func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor) +func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor) -> () + +// ----- + +// CHECK: func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor) +func.func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor) -> () +// ----- + +// expected-error @+1 {{arm.tensor shape dimensions must be either fully dynamic or completed shaped}} +func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<1x?xi32>) -> () + +// ----- + +// expected-error @+1 {{arm.tensors do not support rank zero}} +func.func private @arm_tensor_rank_zero(!spirv.arm.tensor) -> () + +// ----- + +// CHECK: func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>) +func.func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>) -> () + +// ----- + +// expected-error @+1 {{arm.tensors do not support zero dimensions}} +func.func private @arm_tensor_type_zero_dim(!spirv.arm.tensor<0xi32>) -> () diff --git a/mlir/test/Target/SPIRV/tensorARM.mlir b/mlir/test/Target/SPIRV/tensorARM.mlir new file mode 100644 index 0000000000000..25c2a25b47d88 --- /dev/null +++ b/mlir/test/Target/SPIRV/tensorARM.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s + +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK: spirv.func @shaped_int_arm_tensor(%arg0: !spirv.arm.tensor<2xi32>) "None" { + spirv.func @shaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<2xi32>) "None" { + spirv.Return + } + +// ----- + + // CHECK: spirv.func @shaped_rank2_int_arm_tensor(%arg0: !spirv.arm.tensor<2x3xi32>) "None" { + spirv.func @shaped_rank2_int_arm_tensor(%arg0 : !spirv.arm.tensor<2x3xi32>) "None" { + spirv.Return + } + +// ----- + + // CHECK: spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xi64> "None" { + spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xui64> "None" { + // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xi64> + %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xui64> + + spirv.ReturnValue %0: !spirv.arm.tensor<3xui64> + } + +// ----- + + // CHECK: spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" { + spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" { + // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32> + %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32> + + spirv.ReturnValue %0 : !spirv.arm.tensor<3xsi32> + } + +// ----- + + // CHECK: spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" { + spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" { + // CHECK: spirv.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : !spirv.arm.tensor<3xf32> + %0 = spirv.Constant dense<[3., 4., 5.]> : !spirv.arm.tensor<3xf32> + + spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32> + } + +// ----- + + // CHECK: spirv.func @unranked_int_arm_tensor(%arg0: !spirv.arm.tensor<*xi32>) "None" { + spirv.func @unranked_int_arm_tensor(%arg0 : !spirv.arm.tensor<*xi32>) "None" { + spirv.Return + } + +// ----- + + // CHECK: spirv.func @unshaped_int_arm_tensor(%arg0: !spirv.arm.tensor) "None" { + spirv.func @unshaped_int_arm_tensor(%arg0 : !spirv.arm.tensor) "None" { + spirv.Return + } + +// ----- + + // CHECK: spirv.func @unshaped_int_arm_tensor_2(%arg0: !spirv.arm.tensor) "None" { + spirv.func @unshaped_int_arm_tensor_2(%arg0 : !spirv.arm.tensor) "None" { + spirv.Return + } +} From 4750c732d13dd609924045df31dc685cf193849b Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Thu, 19 Jun 2025 10:36:29 +0200 Subject: [PATCH 2/6] Resolve several review comments Signed-off-by: Davide Grohmann Change-Id: I9f99f2e5efc1b433bcf885a4890c730cb8cac213 --- .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 6 ++-- mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 16 ++++----- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 34 +++++++------------ .../SPIRV/Deserialization/Deserializer.cpp | 23 ++++++------- 4 files changed, 33 insertions(+), 46 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 7ffea6e7dba81..2d3a4e1778490 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -97,8 +97,8 @@ class ScalarType : public SPIRVType { std::optional getSizeInBytes(); }; -// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V -// StructType. +// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V +// StructType, or SPIR-V TensorArmType. class CompositeType : public SPIRVType { public: using SPIRVType::SPIRVType; @@ -479,7 +479,7 @@ class MatrixType : public Type::TypeBase storage = std::nullopt); }; -// SPIR-V TensorARM Type +/// SPIR-V TensorARM Type class TensorArmType : public Type::TypeBase { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 15002f1d5d16e..88c7adf3dfcb3 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -194,8 +194,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, << t.getNumElements(); return Type(); } - } else if (auto t = llvm::dyn_cast(type)) { - if (!llvm::isa(t.getElementType())) { + } else if (auto t = dyn_cast(type)) { + if (!isa(t.getElementType())) { parser.emitError( typeLoc, "only scalar element type allowed in tensor type but found ") << t.getElementType(); @@ -385,24 +385,22 @@ static Type parseTensorArmType(SPIRVDialect const &dialect, unranked = true; if (parser.parseXInDimensionList()) return {}; - } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) + } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) { return {}; + } if (!unranked && dims.empty()) { parser.emitError(countLoc, "arm.tensors do not support rank zero"); return {}; } - if (std::any_of(dims.begin(), dims.end(), - [](int64_t dim) { return dim == 0; })) { + if (llvm::is_contained(dims, 0)) { parser.emitError(countLoc, "arm.tensors do not support zero dimensions"); return {}; } - if (std::any_of(dims.begin(), dims.end(), - [](int64_t dim) { return dim < 0; }) && - std::any_of(dims.begin(), dims.end(), - [](int64_t dim) { return dim > 0; })) { + if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) && + llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) { parser.emitError(countLoc, "arm.tensor shape dimensions must be either " "fully dynamic or completed shaped"); return {}; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index e4eeb0a7f37d5..18e13a7d5d3ec 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -20,7 +20,6 @@ #include #include -#include #include using namespace mlir; @@ -127,7 +126,7 @@ unsigned CompositeType::getNumElements() const { return structType.getNumElements(); if (auto vectorType = llvm::dyn_cast(*this)) return vectorType.getNumElements(); - if (auto tensorArmType = llvm::dyn_cast(*this)) + if (auto tensorArmType = dyn_cast(*this)) return tensorArmType.getNumElements(); if (llvm::isa(*this)) { llvm_unreachable( @@ -156,9 +155,7 @@ void CompositeType::getExtensions( .getExtensions(extensions, storage); }) .Case([&](TensorArmType type) { - static const Extension exts[] = {Extension::SPV_ARM_tensors}; - ArrayRef ref(exts, std::size(exts)); - extensions.push_back(ref); + extensions.push_back({Extension::SPV_ARM_tensors}); return llvm::cast(type.getElementType()) .getExtensions(extensions, storage); }) @@ -184,9 +181,7 @@ void CompositeType::getCapabilities( .getCapabilities(capabilities, storage); }) .Case([&](TensorArmType type) { - static const Capability caps[] = {Capability::TensorsARM}; - ArrayRef ref(caps, std::size(caps)); - capabilities.push_back(ref); + capabilities.push_back({Capability::TensorsARM}); return llvm::cast(type.getElementType()) .getCapabilities(capabilities, storage); }) @@ -1245,15 +1240,15 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage { static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { - auto shape = std::get<0>(key); - auto elementType = std::get<1>(key); + auto [shape, elementType] = key; shape = allocator.copyInto(shape); return new (allocator.allocate()) TensorArmTypeStorage(std::move(shape), std::move(elementType)); } static llvm::hash_code hashKey(const KeyTy &key) { - return llvm::hash_combine(std::get<0>(key), std::get<1>(key)); + auto [shape, elementType] = key; + return llvm::hash_combine(shape, elementType); } bool operator==(const KeyTy &key) const { @@ -1280,7 +1275,7 @@ Type TensorArmType::getElementType() const { return getImpl()->elementType; } ArrayRef TensorArmType::getShape() const { return getImpl()->shape; } unsigned TensorArmType::getNumElements() const { - auto shape = getShape(); + ArrayRef shape = getShape(); return std::accumulate(shape.begin(), shape.end(), unsigned(1), std::multiplies()); } @@ -1290,8 +1285,7 @@ void TensorArmType::getExtensions( std::optional storage) { llvm::cast(getElementType()).getExtensions(extensions, storage); - static constexpr Extension exts[] = {Extension::SPV_ARM_tensors}; - extensions.push_back(exts); + extensions.push_back({Extension::SPV_ARM_tensors}); } void TensorArmType::getCapabilities( @@ -1299,20 +1293,16 @@ void TensorArmType::getCapabilities( std::optional storage) { llvm::cast(getElementType()) .getCapabilities(capabilities, storage); - static constexpr Capability caps[] = {Capability::TensorsARM}; - capabilities.push_back(caps); + capabilities.push_back({Capability::TensorsARM}); } LogicalResult TensorArmType::verifyInvariants(function_ref emitError, ArrayRef shape, Type elementType) { - if (std::any_of(shape.begin(), shape.end(), - [](int64_t dim) { return dim == 0; })) + if (llvm::is_contained(shape, 0)) return emitError() << "arm.tensor do not support dimensions = 0"; - if (std::any_of(shape.begin(), shape.end(), - [](int64_t dim) { return dim < 0; }) && - std::any_of(shape.begin(), shape.end(), - [](int64_t dim) { return dim > 0; })) + if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) && + llvm::any_of(shape, [](int64_t dim) { return dim > 0; })) return emitError() << "arm.tensor shape dimensions must be either fully dynamic or " "completed shaped"; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index f0e42047c559e..b801f5a4660fc 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1244,23 +1244,23 @@ spirv::Deserializer::processMatrixType(ArrayRef operands) { LogicalResult spirv::Deserializer::processTensorARMType(ArrayRef operands) { unsigned size = operands.size(); - if (size < 2 || size > 4) { + if (size < 2 || size > 4) return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands " - "(result_id, element_type, (rank), (shape))") + "(result_id, element_type, (rank), (shape)) ") << size; - } + Type elementTy = getType(operands[1]); - if (!elementTy) { + if (!elementTy) return emitError(unknownLoc, - "OpTypeTensorARM references undefined element type.") + "OpTypeTensorARM references undefined element type ") << operands[1]; - } + if (size == 2) { typeMap[operands[0]] = TensorArmType::get({}, elementTy); return success(); } - auto rankAttr = getConstantInt(operands[2]); + IntegerAttr rankAttr = getConstantInt(operands[2]); if (!rankAttr) return emitError(unknownLoc, "OpTypeTensorARM rank must come from a " "scalar integer constant instruction"); @@ -1271,19 +1271,18 @@ spirv::Deserializer::processTensorARMType(ArrayRef operands) { return success(); } - auto shapeInfo = getConstant(operands[3]); - if (!shapeInfo) { + std::optional> shapeInfo = getConstant(operands[3]); + if (!shapeInfo) return emitError(unknownLoc, "OpTypeTensorARM shape must come from a " "constant instruction of type OpTypeArray"); - } + ArrayAttr shapeArrayAttr = llvm::dyn_cast(shapeInfo->first); SmallVector shape; for (auto dimAttr : shapeArrayAttr.getValue()) { auto dimIntAttr = llvm::dyn_cast(dimAttr); - if (!dimIntAttr) { + if (!dimIntAttr) return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid " "dimension size"); - } shape.push_back(dimIntAttr.getValue().getSExtValue()); } typeMap[operands[0]] = TensorArmType::get(shape, elementTy); From a96196f9a5597d9c6d8774836fc1ab95af4493aa Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Thu, 19 Jun 2025 13:53:13 +0200 Subject: [PATCH 3/6] Fix more comment and formatting Signed-off-by: Davide Grohmann Change-Id: I527cdeb796135237ef5d0965d5f8d54a36515458 --- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 2 +- mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 18e13a7d5d3ec..e3ed16da2a6de 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -1243,7 +1243,7 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage { auto [shape, elementType] = key; shape = allocator.copyInto(shape); return new (allocator.allocate()) - TensorArmTypeStorage(std::move(shape), std::move(elementType)); + TensorArmTypeStorage(shape, elementType); } static llvm::hash_code hashKey(const KeyTy &key) { diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index b801f5a4660fc..893aa38da93d1 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1271,7 +1271,8 @@ spirv::Deserializer::processTensorARMType(ArrayRef operands) { return success(); } - std::optional> shapeInfo = getConstant(operands[3]); + std::optional> shapeInfo = + getConstant(operands[3]); if (!shapeInfo) return emitError(unknownLoc, "OpTypeTensorARM shape must come from a " "constant instruction of type OpTypeArray"); From d4464e595ea313b106bf9069005e99fafc23f237 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Wed, 25 Jun 2025 12:21:41 +0200 Subject: [PATCH 4/6] Resolve more review comments Signed-off-by: Davide Grohmann Change-Id: I9ee3c5755d123fcf3ef203ac5d76af9511c5ef7a --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 10 +++++++++- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 15 +++++---------- .../Target/SPIRV/Deserialization/Deserializer.cpp | 1 - mlir/test/Target/SPIRV/tensorARM.mlir | 2 +- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 2d3a4e1778490..6fa09888e887b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -486,6 +486,15 @@ class TensorArmType public: using Base::Base; + using ShapedType::Trait::getElementTypeBitWidth; + using ShapedType::Trait::getRank; + using ShapedType::Trait::getNumElements; + using ShapedType::Trait::isDynamicDim; + using ShapedType::Trait::hasStaticShape; + using ShapedType::Trait::getNumDynamicDims; + using ShapedType::Trait::getDimSize; + using ShapedType::Trait::getDynamicDimIndex; + static constexpr StringLiteral name = "spirv.arm.tensor"; // TensorArm supports minimum rank of 1, hence an empty shape here means @@ -500,7 +509,6 @@ class TensorArmType Type getElementType() const; ArrayRef getShape() const; - unsigned getNumElements() const; bool hasRank() const { return !getShape().empty(); } operator ShapedType() const { return llvm::cast(*this); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index e3ed16da2a6de..b288eb2edc315 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -155,7 +155,8 @@ void CompositeType::getExtensions( .getExtensions(extensions, storage); }) .Case([&](TensorArmType type) { - extensions.push_back({Extension::SPV_ARM_tensors}); + static constexpr Extension ext{Extension::SPV_ARM_tensors}; + extensions.push_back(ext); return llvm::cast(type.getElementType()) .getExtensions(extensions, storage); }) @@ -181,7 +182,8 @@ void CompositeType::getCapabilities( .getCapabilities(capabilities, storage); }) .Case([&](TensorArmType type) { - capabilities.push_back({Capability::TensorsARM}); + static constexpr Capability cap{Capability::TensorsARM}; + capabilities.push_back(cap); return llvm::cast(type.getElementType()) .getCapabilities(capabilities, storage); }) @@ -712,9 +714,8 @@ bool SPIRVType::classof(Type type) { return true; if (auto vectorType = llvm::dyn_cast(type)) return CompositeType::isValid(vectorType); - if (auto tensorArmType = llvm::dyn_cast(type)) { + if (auto tensorArmType = llvm::dyn_cast(type)) return llvm::isa(tensorArmType.getElementType()); - } return false; } @@ -1274,12 +1275,6 @@ TensorArmType TensorArmType::cloneWith(std::optional> shape, Type TensorArmType::getElementType() const { return getImpl()->elementType; } ArrayRef TensorArmType::getShape() const { return getImpl()->shape; } -unsigned TensorArmType::getNumElements() const { - ArrayRef shape = getShape(); - return std::accumulate(shape.begin(), shape.end(), unsigned(1), - std::multiplies()); -} - void TensorArmType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 893aa38da93d1..b1abd8b3dffe9 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -138,7 +138,6 @@ LogicalResult spirv::Deserializer::processHeader() { MIN_VERSION_CASE(3); MIN_VERSION_CASE(4); MIN_VERSION_CASE(5); - MIN_VERSION_CASE(6); #undef MIN_VERSION_CASE default: return emitError(unknownLoc, "unsupported SPIR-V minor version: ") diff --git a/mlir/test/Target/SPIRV/tensorARM.mlir b/mlir/test/Target/SPIRV/tensorARM.mlir index 25c2a25b47d88..75b648ebfd008 100644 --- a/mlir/test/Target/SPIRV/tensorARM.mlir +++ b/mlir/test/Target/SPIRV/tensorARM.mlir @@ -1,6 +1,6 @@ // RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s -spirv.module Logical GLSL450 requires #spirv.vce { +spirv.module Logical GLSL450 requires #spirv.vce { // CHECK: spirv.func @shaped_int_arm_tensor(%arg0: !spirv.arm.tensor<2xi32>) "None" { spirv.func @shaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<2xi32>) "None" { spirv.Return From 9be70df51d05aca2ec3b3d58b88521e6a8ef5198 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Wed, 25 Jun 2025 12:39:39 +0200 Subject: [PATCH 5/6] Resolve small review comments Signed-off-by: Davide Grohmann Change-Id: I8a67bb45402ecb7cfb90c2eaa85e3879783af415 --- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index b288eb2edc315..2b90df42af5cc 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -1280,7 +1280,8 @@ void TensorArmType::getExtensions( std::optional storage) { llvm::cast(getElementType()).getExtensions(extensions, storage); - extensions.push_back({Extension::SPV_ARM_tensors}); + static constexpr Extension ext{Extension::SPV_ARM_tensors}; + extensions.push_back(ext); } void TensorArmType::getCapabilities( @@ -1288,7 +1289,8 @@ void TensorArmType::getCapabilities( std::optional storage) { llvm::cast(getElementType()) .getCapabilities(capabilities, storage); - capabilities.push_back({Capability::TensorsARM}); + static constexpr Capability cap{Capability::TensorsARM}; + capabilities.push_back(cap); } LogicalResult From eabfff75d2f5adad40b7ccec708f3c549a176099 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Tue, 1 Jul 2025 23:25:26 +0200 Subject: [PATCH 6/6] Resolve last review comment Signed-off-by: Davide Grohmann Change-Id: Ief1c782e1ab637eb883b0bd6d324198419b9de0c --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 6fa09888e887b..212cba61d396c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -486,14 +486,15 @@ class TensorArmType public: using Base::Base; - using ShapedType::Trait::getElementTypeBitWidth; - using ShapedType::Trait::getRank; - using ShapedType::Trait::getNumElements; - using ShapedType::Trait::isDynamicDim; - using ShapedType::Trait::hasStaticShape; - using ShapedType::Trait::getNumDynamicDims; - using ShapedType::Trait::getDimSize; - using ShapedType::Trait::getDynamicDimIndex; + using ShapedTypeTraits = ShapedType::Trait; + using ShapedTypeTraits::getDimSize; + using ShapedTypeTraits::getDynamicDimIndex; + using ShapedTypeTraits::getElementTypeBitWidth; + using ShapedTypeTraits::getNumDynamicDims; + using ShapedTypeTraits::getNumElements; + using ShapedTypeTraits::getRank; + using ShapedTypeTraits::hasStaticShape; + using ShapedTypeTraits::isDynamicDim; static constexpr StringLiteral name = "spirv.arm.tensor";