Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m

def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;

def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;

def SPIRV_ExtensionAttr :
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
Expand All @@ -445,6 +447,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader,
SPV_ARM_tensors,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
Expand Down Expand Up @@ -1311,6 +1314,24 @@ def SPIRV_C_GeometryStreams : I32EnumAttrCase<"Geome
def SPIRV_C_MultiViewport : I32EnumAttrCase<"MultiViewport", 57> {
list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
}
def SPIRV_C_TensorsARM : I32EnumAttrCase<"TensorsARM", 4174> {
list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
list<Availability> availability = [
Extension<[SPV_ARM_tensors]>
];
}
def SPIRV_C_StorageTensorArrayDynamicIndexingEXT : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
list<Availability> availability = [
Extension<[SPV_ARM_tensors]>
];
}
def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
list<Availability> availability = [
Extension<[SPV_ARM_tensors]>
];
}
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
list<Availability> availability = [
Expand Down Expand Up @@ -1523,6 +1544,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
Expand Down Expand Up @@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;

def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;

// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
Expand Down Expand Up @@ -4217,6 +4240,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
"any SPIR-V struct type">;
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;
def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
"any SPIR-V tensorArm type">;

def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
Expand All @@ -4228,7 +4253,7 @@ def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
SPIRV_AnyImage
SPIRV_AnyImage, SPIRV_AnyTensorArm
]>;

def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
Expand Down Expand Up @@ -4525,6 +4550,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
Expand Down Expand Up @@ -4638,7 +4664,9 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformLogicalXor,
SPIRV_OC_OpTypeTensorARM,
SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
Expand Down
35 changes: 34 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace spirv {
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct TensorArmTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
Expand Down Expand Up @@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
std::optional<int64_t> getSizeInBytes();
};

// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
// StructType, or SPIR-V TensorArmType.
class CompositeType : public SPIRVType {
public:
using SPIRVType::SPIRVType;
Expand Down Expand Up @@ -477,6 +479,37 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
std::optional<StorageClass> storage = std::nullopt);
};

/// SPIR-V TensorARM Type
class TensorArmType
: public Type::TypeBase<TensorArmType, CompositeType,
detail::TensorArmTypeStorage, ShapedType::Trait> {
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.arm.tensor";

// TensorArm supports minimum rank of 1, hence an empty shape here means
// unranked.
static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;

static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType);

Type getElementType() const;
ArrayRef<int64_t> getShape() const;
unsigned getNumElements() const;
bool hasRank() const { return !getShape().empty(); }
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};

} // namespace spirv
} // namespace mlir

Expand Down
75 changes: 74 additions & 1 deletion mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
<< t.getNumElements();
return Type();
}
} else if (auto t = dyn_cast<TensorArmType>(type)) {
if (!isa<ScalarType>(t.getElementType())) {
parser.emitError(
typeLoc, "only scalar element type allowed in tensor type but found ")
<< t.getElementType();
return Type();
}
} else {
parser.emitError(typeLoc, "cannot use ")
<< type << " to compose SPIR-V types";
Expand Down Expand Up @@ -363,6 +370,52 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
}

// tensor-arm-type ::=
// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
static Type parseTensorArmType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
return {};

bool unranked = false;
SmallVector<int64_t, 4> dims;
SMLoc countLoc = parser.getCurrentLocation();

if (parser.parseOptionalStar().succeeded()) {
unranked = true;
if (parser.parseXInDimensionList())
return {};
} else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
return {};
}

if (!unranked && dims.empty()) {
parser.emitError(countLoc, "arm.tensors do not support rank zero");
return {};
}

if (llvm::is_contained(dims, 0)) {
parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
return {};
}

if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
"fully dynamic or completed shaped");
return {};
}

auto elementTy = parseAndVerifyType(dialect, parser);
if (!elementTy)
return {};

if (parser.parseGreater())
return {};

return TensorArmType::get(dims, elementTy);
}

// TODO: Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
Expand Down Expand Up @@ -759,6 +812,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseStructType(*this, parser);
if (keyword == "matrix")
return parseMatrixType(*this, parser);
if (keyword == "arm.tensor")
return parseTensorArmType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
return Type();
}
Expand Down Expand Up @@ -855,10 +910,28 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
os << ">";
}

static void print(TensorArmType type, DialectAsmPrinter &os) {
os << "arm.tensor<";

llvm::interleave(
type.getShape(), os,
[&](int64_t dim) {
if (ShapedType::isDynamic(dim))
os << '?';
else
os << dim;
},
"x");
if (!type.hasRank()) {
os << "*";
}
os << "x" << type.getElementType() << ">";
}

void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
ImageType, SampledImageType, StructType, MatrixType>(
ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
[&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,12 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
return failure();
}

if (llvm::isa<TensorArmType>(type)) {
if (parser.parseOptionalColon().succeeded())
if (parser.parseType(type))
return failure();
}

return parser.addTypeToList(type, result.types);
}

Expand Down
Loading
Loading