diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d2ba76cdad904..d874817e6888d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -422,6 +422,8 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
 def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
 
+def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+
 def SPIRV_ExtensionAttr :
     SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
       SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
@@ -445,6 +447,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
       SPV_EXT_mesh_shader,
+      SPV_ARM_tensors,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1311,6 +1314,24 @@ def SPIRV_C_GeometryStreams : I32EnumAttrCase<"Geome
 def SPIRV_C_MultiViewport : I32EnumAttrCase<"MultiViewport", 57> {
   list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
 }
+def SPIRV_C_TensorsARM : I32EnumAttrCase<"TensorsARM", 4174> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
+def SPIRV_C_StorageTensorArrayDynamicIndexingEXT : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
+def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
 def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
   list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
   list<Availability> availability = [
@@ -1523,6 +1544,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
       SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
       SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
+      SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
+      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
       SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
       SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
       SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
 def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
 def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
 def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
-
+def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
 
 // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
 // for the definition of the following types and type categories.
@@ -4217,6 +4240,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
                                   "any SPIR-V struct type">;
 def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
                                         "any SPIR-V sampled image type">;
+def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
+                                     "any SPIR-V tensorArm type">;
 
 def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
@@ -4228,7 +4253,7 @@ def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
-    SPIRV_AnyImage
+    SPIRV_AnyImage, SPIRV_AnyTensorArm
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4525,6 +4550,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
 def SPIRV_OC_OpGroupNonUniformLogicalOr  : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
 def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4638,7 +4664,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
       SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
-      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+      SPIRV_OC_OpGroupNonUniformLogicalXor,
+      SPIRV_OC_OpTypeTensorARM,
+      SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
       SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 787535d0a6bd2..212cba61d396c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,6 +29,7 @@ namespace spirv {
 namespace detail {
 struct ArrayTypeStorage;
 struct CooperativeMatrixTypeStorage;
+struct TensorArmTypeStorage;
 struct ImageTypeStorage;
 struct MatrixTypeStorage;
 struct PointerTypeStorage;
@@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
   std::optional<int64_t> getSizeInBytes();
 };
 
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
+// StructType, or SPIR-V TensorArmType.
 class CompositeType : public SPIRVType {
 public:
   using SPIRVType::SPIRVType;
@@ -477,6 +479,46 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
                        std::optional<StorageClass> storage = std::nullopt);
 };
 
+/// SPIR-V TensorARM Type
+class TensorArmType
+    : public Type::TypeBase<TensorArmType, CompositeType,
+                            detail::TensorArmTypeStorage, ShapedType::Trait> {
+public:
+  using Base::Base;
+
+  using ShapedTypeTraits = ShapedType::Trait<TensorArmType>;
+  using ShapedTypeTraits::getDimSize;
+  using ShapedTypeTraits::getDynamicDimIndex;
+  using ShapedTypeTraits::getElementTypeBitWidth;
+  using ShapedTypeTraits::getNumDynamicDims;
+  using ShapedTypeTraits::getNumElements;
+  using ShapedTypeTraits::getRank;
+  using ShapedTypeTraits::hasStaticShape;
+  using ShapedTypeTraits::isDynamicDim;
+
+  static constexpr StringLiteral name = "spirv.arm.tensor";
+
+  // TensorArm supports a minimum rank of 1, hence an empty shape here means
+  // unranked.
+  static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
+  TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                          Type elementType) const;
+
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<int64_t> shape, Type elementType);
+
+  Type getElementType() const;
+  ArrayRef<int64_t> getShape() const;
+  bool hasRank() const { return !getShape().empty(); }
+  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     std::optional<StorageClass> storage = std::nullopt);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       std::optional<StorageClass> storage = std::nullopt);
+};
+
 } // namespace spirv
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a21acef1c4b43..88c7adf3dfcb3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
                        << t.getNumElements();
       return Type();
     }
+  } else if (auto t = dyn_cast<TensorArmType>(type)) {
+    if (!isa<ScalarType>(t.getElementType())) {
+      parser.emitError(
+          typeLoc, "only scalar element type allowed in tensor type but found ")
+          << t.getElementType();
+      return Type();
+    }
   } else {
     parser.emitError(typeLoc, "cannot use ")
         << type << " to compose SPIR-V types";
@@ -363,6 +370,52 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
   return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
 }
 
+// tensor-arm-type ::=
+//   `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
+static Type parseTensorArmType(SPIRVDialect const &dialect,
+                               DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return {};
+
+  bool unranked = false;
+  SmallVector<int64_t> dims;
+  SMLoc countLoc = parser.getCurrentLocation();
+
+  if (parser.parseOptionalStar().succeeded()) {
+    unranked = true;
+    if (parser.parseXInDimensionList())
+      return {};
+  } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
+    return {};
+  }
+
+  if (!unranked && dims.empty()) {
+    parser.emitError(countLoc, "arm.tensors do not support rank zero");
+    return {};
+  }
+
+  if (llvm::is_contained(dims, 0)) {
+    parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
+    return {};
+  }
+
+  if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
+      llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
+    parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
+                               "fully dynamic or completed shaped");
+    return {};
+  }
+
+  auto elementTy = parseAndVerifyType(dialect, parser);
+  if (!elementTy)
+    return {};
+
+  if (parser.parseGreater())
+    return {};
+
+  return TensorArmType::get(dims, elementTy);
+}
+
 // TODO: Reorder methods to be utilities first and parse*Type
 // methods in alphabetical order
 //
@@ -759,6 +812,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseStructType(*this, parser);
   if (keyword == "matrix")
     return parseMatrixType(*this, parser);
+  if (keyword == "arm.tensor")
+    return parseTensorArmType(*this, parser);
   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
   return Type();
 }
@@ -855,10 +910,28 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
   os << ">";
 }
 
+static void print(TensorArmType type, DialectAsmPrinter &os) {
+  os << "arm.tensor<";
+
+  llvm::interleave(
+      type.getShape(), os,
+      [&](int64_t dim) {
+        if (ShapedType::isDynamic(dim))
+          os << '?';
+        else
+          os << dim;
+      },
+      "x");
+  if (!type.hasRank()) {
+    os << "*";
+  }
+  os << "x" << type.getElementType() << ">";
+}
+
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
-            ImageType, SampledImageType, StructType, MatrixType>(
+            ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
           [&](auto type) { print(type, os); })
       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 7148027dae78d..eb2974d62fdd1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -547,6 +547,12 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
       return failure();
   }
 
+  if (llvm::isa<TensorArmType>(type)) {
+    if (parser.parseOptionalColon().succeeded())
+      if (parser.parseType(type))
+        return failure();
+  }
+
   return parser.addTypeToList(type, result.types);
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 93e0c9b33c546..2b90df42af5cc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -18,8 +18,9 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <algorithm>
 #include <cstdint>
-#include <iterator>
+#include <tuple>
 
 using namespace mlir;
 using namespace mlir::spirv;
@@ -96,7 +97,7 @@ bool CompositeType::classof(Type type) {
     return isValid(vectorType);
   return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
                    spirv::MatrixType, spirv::RuntimeArrayType,
-                   spirv::StructType>(type);
+                   spirv::StructType, spirv::TensorArmType>(type);
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -107,8 +108,8 @@ bool CompositeType::isValid(VectorType type) {
 
 Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
-          [](auto type) { return type.getElementType(); })
+      .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
+            TensorArmType>([](auto type) { return type.getElementType(); })
       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
       .Case<StructType>(
           [index](StructType type) { return type.getElementType(index); })
@@ -125,6 +126,8 @@ unsigned CompositeType::getNumElements() const {
     return structType.getNumElements();
   if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
     return vectorType.getNumElements();
+  if (auto tensorArmType = dyn_cast<TensorArmType>(*this))
+    return tensorArmType.getNumElements();
   if (llvm::isa<CooperativeMatrixType>(*this)) {
     llvm_unreachable(
         "invalid to query number of elements of spirv Cooperative Matrix type");
@@ -151,6 +154,13 @@ void CompositeType::getExtensions(
         return llvm::cast<SPIRVType>(type.getElementType())
             .getExtensions(extensions, storage);
       })
+      .Case<TensorArmType>([&](TensorArmType type) {
+        static constexpr Extension ext{Extension::SPV_ARM_tensors};
+        extensions.push_back(ext);
+        return llvm::cast<SPIRVType>(type.getElementType())
+            .getExtensions(extensions, storage);
+      })
+
       .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
@@ -171,6 +181,12 @@ void CompositeType::getCapabilities(
         return llvm::cast<SPIRVType>(type.getElementType())
            .getCapabilities(capabilities, storage);
       })
+      .Case<TensorArmType>([&](TensorArmType type) {
+        static constexpr Capability cap{Capability::TensorsARM};
+        capabilities.push_back(cap);
+        return llvm::cast<SPIRVType>(type.getElementType())
+            .getCapabilities(capabilities, storage);
+      })
       .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
@@ -186,6 +202,13 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
       return std::nullopt;
     return *elementSize * vectorType.getNumElements();
   }
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    std::optional<int64_t> elementSize =
+        llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
+    if (!elementSize)
+      return std::nullopt;
+    return *elementSize * tensorArmType.getNumElements();
+  }
   return std::nullopt;
 }
@@ -691,6 +714,8 @@ bool SPIRVType::classof(Type type) {
     return true;
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return CompositeType::isValid(vectorType);
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
+    return llvm::isa<ScalarType>(tensorArmType.getElementType());
   return false;
 }
 
@@ -712,6 +737,8 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
     matrixType.getExtensions(extensions, storage);
   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getExtensions(extensions, storage);
+  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    tensorArmType.getExtensions(extensions, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getExtensions");
   }
@@ -732,6 +759,8 @@ void SPIRVType::getCapabilities(
     matrixType.getCapabilities(capabilities, storage);
   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getCapabilities(capabilities, storage);
+  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    tensorArmType.getCapabilities(capabilities, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
   }
@@ -1203,11 +1232,85 @@ void MatrixType::getCapabilities(
   llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
 }
 
+//===----------------------------------------------------------------------===//
+// TensorArmType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
+  using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
+
+  static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    auto [shape, elementType] = key;
+    shape = allocator.copyInto(shape);
+    return new (allocator.allocate<TensorArmTypeStorage>())
+        TensorArmTypeStorage(shape, elementType);
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    auto [shape, elementType] = key;
+    return llvm::hash_combine(shape, elementType);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(shape, elementType);
+  }
+
+  TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
+      : shape(std::move(shape)), elementType(std::move(elementType)) {}
+
+  ArrayRef<int64_t> shape;
+  Type elementType;
+};
+
+TensorArmType TensorArmType::get(ArrayRef<int64_t> shape, Type elementType) {
+  return Base::get(elementType.getContext(), shape, elementType);
+}
+
+TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                       Type elementType) const {
+  return TensorArmType::get(shape.value_or(getShape()), elementType);
+}
+
+Type TensorArmType::getElementType() const { return getImpl()->elementType; }
+ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+
+void TensorArmType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    std::optional<StorageClass> storage) {
+
+  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  static constexpr Extension ext{Extension::SPV_ARM_tensors};
+  extensions.push_back(ext);
+}
+
+void TensorArmType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    std::optional<StorageClass> storage) {
+  llvm::cast<SPIRVType>(getElementType())
+      .getCapabilities(capabilities, storage);
+  static constexpr Capability cap{Capability::TensorsARM};
+  capabilities.push_back(cap);
+}
+
+LogicalResult
+TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                ArrayRef<int64_t> shape, Type elementType) {
+  if (llvm::is_contained(shape, 0))
+    return emitError() << "arm.tensor do not support dimensions = 0";
+  if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
+      llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
+    return emitError() << "arm.tensor shape dimensions must be either fully "
+                          "dynamic or completed shaped";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V Dialect
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::registerTypes() {
   addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
-           RuntimeArrayType, SampledImageType, StructType>();
+           RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index b30da773d4896..55d6a380d0bff 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -164,6 +164,7 @@ LogicalResult spirv::Deserializer::processInstruction(
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
+  case spirv::Opcode::OpTypeTensorARM:
   case spirv::Opcode::OpTypeCooperativeMatrixKHR:
     return processType(opcode, operands);
   case spirv::Opcode::OpTypeForwardPointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b9d9a9015eb61..b1abd8b3dffe9 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -935,6 +935,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     return processStructType(operands);
   case spirv::Opcode::OpTypeMatrix:
     return processMatrixType(operands);
+  case spirv::Opcode::OpTypeTensorARM:
+    return processTensorARMType(operands);
   default:
     return emitError(unknownLoc, "unhandled type instruction");
   }
@@ -1238,6 +1240,55 @@ spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult
+spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
+  unsigned size = operands.size();
+  if (size < 2 || size > 4)
+    return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
+                                 "(result_id, element_type, (rank), (shape)) ")
+           << size;
+
+  Type elementTy = getType(operands[1]);
+  if (!elementTy)
+    return emitError(unknownLoc,
+                     "OpTypeTensorARM references undefined element type ")
+           << operands[1];
+
+  if (size == 2) {
+    typeMap[operands[0]] = TensorArmType::get({}, elementTy);
+    return success();
+  }
+
+  IntegerAttr rankAttr = getConstantInt(operands[2]);
+  if (!rankAttr)
+    return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
+                                 "scalar integer constant instruction");
+  unsigned rank = rankAttr.getValue().getZExtValue();
+  if (size == 3) {
+    SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
+    typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+    return success();
+  }
+
+  std::optional<std::pair<Attribute, Type>> shapeInfo =
+      getConstant(operands[3]);
+  if (!shapeInfo)
+    return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
+                                 "constant instruction of type OpTypeArray");
+
+  ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
+  SmallVector<int64_t> shape;
+  for (auto dimAttr : shapeArrayAttr.getValue()) {
+    auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
+    if (!dimIntAttr)
+      return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
+                                   "dimension size");
+    shape.push_back(dimIntAttr.getValue().getSExtValue());
+  }
+  typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index e4556e7652b17..1bc9e4a3c75d8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -291,6 +291,8 @@ class Deserializer {
 
   LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
+
   LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index d258bfd852961..ebebd2d283afa 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -729,6 +729,54 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+    uint32_t elementTypeID = 0;
+    uint32_t rank = 0;
+    uint32_t shapeID = 0;
+    uint32_t rankID = 0;
+    if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
+                               elementTypeID, serializationCtx))) {
+      return failure();
+    }
+    if (tensorArmType.hasRank()) {
+      ArrayRef<int64_t> dims = tensorArmType.getShape();
+      rank = dims.size();
+      rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
+      if (rankID == 0) {
+        return failure();
+      }
+
+      bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
+      if (rank > 0 && shaped) {
+        auto I32Type = IntegerType::get(type.getContext(), 32);
+        auto shapeType = ArrayType::get(I32Type, rank);
+        if (rank == 1) {
+          SmallVector<uint64_t> index(rank);
+          shapeID = prepareDenseElementsConstant(
+              loc, shapeType,
+              mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
+              index);
+        } else {
+          shapeID = prepareArrayConstant(
+              loc, shapeType,
+              mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
+        }
+        if (shapeID == 0) {
+          return failure();
+        }
+      }
+    }
+    typeEnum = spirv::Opcode::OpTypeTensorARM;
+    operands.push_back(elementTypeID);
+    if (rankID == 0)
+      return success();
+    operands.push_back(rankID);
+    if (shapeID == 0)
+      return success();
+    operands.push_back(shapeID);
+    return success();
+  }
+
   // TODO: Handle other types.
   return emitError(loc, "unhandled type in serialization: ") << type;
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index c23894c62826b..7d45b5ea82643 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -564,3 +564,54 @@ func.func private @matrix_size_type(!spirv.matrix< x vector<3xi32>>) -> ()
 func.func private @matrix_size_type(!spirv.matrix<2.0 x vector<3xi32>>) -> ()
 
 // -----
+
+//===----------------------------------------------------------------------===//
+// TensorArm
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>)
+func.func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>)
+func.func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>)
+func.func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>)
+func.func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>)
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>)
+func.func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>) -> ()
+// -----
+
+// expected-error @+1 {{arm.tensor shape dimensions must be either fully dynamic or completed shaped}}
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<1x?xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support rank zero}}
+func.func private @arm_tensor_rank_zero(!spirv.arm.tensor<i32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>)
+func.func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support zero dimensions}}
+func.func private @arm_tensor_type_zero_dim(!spirv.arm.tensor<0xi32>) -> ()
diff --git a/mlir/test/Target/SPIRV/tensorARM.mlir b/mlir/test/Target/SPIRV/tensorARM.mlir
new file mode 100644
index 0000000000000..75b648ebfd008
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tensorARM.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorsARM, Int64, Shader], [SPV_ARM_tensors]> {
+  // CHECK: spirv.func @shaped_int_arm_tensor(%arg0: !spirv.arm.tensor<2xi32>) "None" {
+  spirv.func @shaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<2xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @shaped_rank2_int_arm_tensor(%arg0: !spirv.arm.tensor<2x3xi32>) "None" {
+  spirv.func @shaped_rank2_int_arm_tensor(%arg0 : !spirv.arm.tensor<2x3xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xi64> "None" {
+  spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xui64> "None" {
+    // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xi64>
+    %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xui64>
+
+    spirv.ReturnValue %0: !spirv.arm.tensor<3xui64>
+  }
+
+// -----
+
+  // CHECK: spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+  spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+    // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+    %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xsi32>
+  }
+
+// -----
+
+  // CHECK: spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+  spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+    // CHECK: spirv.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : !spirv.arm.tensor<3xf32>
+    %0 = spirv.Constant dense<[3., 4., 5.]> : !spirv.arm.tensor<3xf32>
+
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+  }
+
+// -----
+
+  // CHECK: spirv.func @unranked_int_arm_tensor(%arg0: !spirv.arm.tensor<*xi32>) "None" {
+  spirv.func @unranked_int_arm_tensor(%arg0 : !spirv.arm.tensor<*xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @unshaped_int_arm_tensor(%arg0: !spirv.arm.tensor<?xi32>) "None" {
+  spirv.func @unshaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<?xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @unshaped_int_arm_tensor_2(%arg0: !spirv.arm.tensor<?x?xi32>) "None" {
+  spirv.func @unshaped_int_arm_tensor_2(%arg0 : !spirv.arm.tensor<?x?xi32>) "None" {
+    spirv.Return
+  }
+}
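Usage sketch (not part of the patch): with the parser, printer, and ConstantOp changes above applied, the new type should be expressible directly in SPIR-V dialect assembly. The function names below and the exact #spirv.vce triple are illustrative assumptions, mirroring the round-trip tests rather than quoting them:

  spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorsARM, Shader], [SPV_ARM_tensors]> {
    // A ranked, statically shaped tensor argument.
    spirv.func @takes_tensor(%arg0 : !spirv.arm.tensor<2x3xi32>) "None" {
      spirv.Return
    }
    // A constant materialized with the new type.
    spirv.func @const_tensor() -> !spirv.arm.tensor<3xi32> "None" {
      %0 = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
      spirv.ReturnValue %0 : !spirv.arm.tensor<3xi32>
    }
  }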