Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ class ScalarType : public SPIRVType {
std::optional<int64_t> getSizeInBytes();
};

// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V
// StructType.
// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
// StructType, or SPIR-V TensorArmType.
class CompositeType : public SPIRVType {
public:
using SPIRVType::SPIRVType;
Expand Down Expand Up @@ -479,7 +479,7 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
std::optional<StorageClass> storage = std::nullopt);
};

// SPIR-V TensorARM Type
/// SPIR-V TensorARM Type
class TensorArmType
: public Type::TypeBase<TensorArmType, CompositeType,
detail::TensorArmTypeStorage, ShapedType::Trait> {
Expand Down
16 changes: 7 additions & 9 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
<< t.getNumElements();
return Type();
}
} else if (auto t = llvm::dyn_cast<TensorArmType>(type)) {
if (!llvm::isa<ScalarType>(t.getElementType())) {
} else if (auto t = dyn_cast<TensorArmType>(type)) {
if (!isa<ScalarType>(t.getElementType())) {
parser.emitError(
typeLoc, "only scalar element type allowed in tensor type but found ")
<< t.getElementType();
Expand Down Expand Up @@ -385,24 +385,22 @@ static Type parseTensorArmType(SPIRVDialect const &dialect,
unranked = true;
if (parser.parseXInDimensionList())
return {};
} else if (parser.parseDimensionList(dims, /*allowDynamic=*/true))
} else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
return {};
}

if (!unranked && dims.empty()) {
parser.emitError(countLoc, "arm.tensors do not support rank zero");
return {};
}

if (std::any_of(dims.begin(), dims.end(),
[](int64_t dim) { return dim == 0; })) {
if (llvm::is_contained(dims, 0)) {
parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
return {};
}

if (std::any_of(dims.begin(), dims.end(),
[](int64_t dim) { return dim < 0; }) &&
std::any_of(dims.begin(), dims.end(),
[](int64_t dim) { return dim > 0; })) {
if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
"fully dynamic or completed shaped");
return {};
Expand Down
34 changes: 12 additions & 22 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <numeric>

using namespace mlir;
Expand Down Expand Up @@ -127,7 +126,7 @@ unsigned CompositeType::getNumElements() const {
return structType.getNumElements();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this))
if (auto tensorArmType = dyn_cast<TensorArmType>(*this))
return tensorArmType.getNumElements();
if (llvm::isa<CooperativeMatrixType>(*this)) {
llvm_unreachable(
Expand Down Expand Up @@ -156,9 +155,7 @@ void CompositeType::getExtensions(
.getExtensions(extensions, storage);
})
.Case<TensorArmType>([&](TensorArmType type) {
static const Extension exts[] = {Extension::SPV_ARM_tensors};
ArrayRef<Extension> ref(exts, std::size(exts));
extensions.push_back(ref);
extensions.push_back({Extension::SPV_ARM_tensors});
return llvm::cast<ScalarType>(type.getElementType())
.getExtensions(extensions, storage);
})
Expand All @@ -184,9 +181,7 @@ void CompositeType::getCapabilities(
.getCapabilities(capabilities, storage);
})
.Case<TensorArmType>([&](TensorArmType type) {
static const Capability caps[] = {Capability::TensorsARM};
ArrayRef<Capability> ref(caps, std::size(caps));
capabilities.push_back(ref);
capabilities.push_back({Capability::TensorsARM});
return llvm::cast<ScalarType>(type.getElementType())
.getCapabilities(capabilities, storage);
})
Expand Down Expand Up @@ -1245,15 +1240,15 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage {

static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
auto shape = std::get<0>(key);
auto elementType = std::get<1>(key);
auto [shape, elementType] = key;
shape = allocator.copyInto(shape);
return new (allocator.allocate<TensorArmTypeStorage>())
TensorArmTypeStorage(std::move(shape), std::move(elementType));
}

static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
auto [shape, elementType] = key;
return llvm::hash_combine(shape, elementType);
}

bool operator==(const KeyTy &key) const {
Expand All @@ -1280,7 +1275,7 @@ Type TensorArmType::getElementType() const { return getImpl()->elementType; }
ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }

unsigned TensorArmType::getNumElements() const {
auto shape = getShape();
ArrayRef<int64_t> shape = getShape();
return std::accumulate(shape.begin(), shape.end(), unsigned(1),
std::multiplies<unsigned>());
}
Expand All @@ -1290,29 +1285,24 @@ void TensorArmType::getExtensions(
std::optional<StorageClass> storage) {

llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
extensions.push_back(exts);
extensions.push_back({Extension::SPV_ARM_tensors});
}

void TensorArmType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getElementType())
.getCapabilities(capabilities, storage);
static constexpr Capability caps[] = {Capability::TensorsARM};
capabilities.push_back(caps);
capabilities.push_back({Capability::TensorsARM});
}

LogicalResult
TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType) {
if (std::any_of(shape.begin(), shape.end(),
[](int64_t dim) { return dim == 0; }))
if (llvm::is_contained(shape, 0))
return emitError() << "arm.tensor do not support dimensions = 0";
if (std::any_of(shape.begin(), shape.end(),
[](int64_t dim) { return dim < 0; }) &&
std::any_of(shape.begin(), shape.end(),
[](int64_t dim) { return dim > 0; }))
if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
return emitError()
<< "arm.tensor shape dimensions must be either fully dynamic or "
"completed shaped";
Expand Down
23 changes: 11 additions & 12 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1244,23 +1244,23 @@ spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
LogicalResult
spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
unsigned size = operands.size();
if (size < 2 || size > 4) {
if (size < 2 || size > 4)
return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
"(result_id, element_type, (rank), (shape))")
"(result_id, element_type, (rank), (shape)) ")
<< size;
}

Type elementTy = getType(operands[1]);
if (!elementTy) {
if (!elementTy)
return emitError(unknownLoc,
"OpTypeTensorARM references undefined element type.")
"OpTypeTensorARM references undefined element type ")
<< operands[1];
}

if (size == 2) {
typeMap[operands[0]] = TensorArmType::get({}, elementTy);
return success();
}

auto rankAttr = getConstantInt(operands[2]);
IntegerAttr rankAttr = getConstantInt(operands[2]);
if (!rankAttr)
return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
"scalar integer constant instruction");
Expand All @@ -1271,19 +1271,18 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
return success();
}

auto shapeInfo = getConstant(operands[3]);
if (!shapeInfo) {
std::optional<std::pair<Attribute, Type>> shapeInfo = getConstant(operands[3]);
if (!shapeInfo)
return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
"constant instruction of type OpTypeArray");
}

ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
SmallVector<int64_t, 1> shape;
for (auto dimAttr : shapeArrayAttr.getValue()) {
auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
if (!dimIntAttr) {
if (!dimIntAttr)
return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
"dimension size");
}
shape.push_back(dimIntAttr.getValue().getSExtValue());
}
typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
Expand Down
Loading