Skip to content

Commit c526c70

Browse files
authored
[mlir][spirv] Rework type size calculation (#160162)
Similar to `::getExtensions` and `::getCapabilities`, introduce a single entry point for type size calculation. Also fix potential infinite recursion with `StructType`s (even non-recursive structs), although I don't know to write a test for this without using C++. This is mostly an NFC modulo this potential bug fix.
1 parent 8b824f3 commit c526c70

File tree

2 files changed

+30
-56
lines changed

2 files changed

+30
-56
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ class ScalarType : public SPIRVType {
8888
static bool isValid(FloatType);
8989
/// Returns true if the given float type is valid for the SPIR-V dialect.
9090
static bool isValid(IntegerType);
91-
92-
std::optional<int64_t> getSizeInBytes();
9391
};
9492

9593
// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
@@ -112,8 +110,6 @@ class CompositeType : public SPIRVType {
112110
/// Return true if the number of elements is known at compile time and is not
113111
/// implementation dependent.
114112
bool hasCompileTimeKnownNumElements() const;
115-
116-
std::optional<int64_t> getSizeInBytes();
117113
};
118114

119115
// SPIR-V array type
@@ -137,10 +133,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
137133
/// Returns the array stride in bytes. 0 means no stride decorated on this
138134
/// type.
139135
unsigned getArrayStride() const;
140-
141-
/// Returns the array size in bytes. Since array type may have an explicit
142-
/// stride declaration (in bytes), we also include it in the calculation.
143-
std::optional<int64_t> getSizeInBytes();
144136
};
145137

146138
// SPIR-V image type

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 30 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/Support/ErrorHandling.h"
2121

2222
#include <cstdint>
23+
#include <optional>
2324

2425
using namespace mlir;
2526
using namespace mlir::spirv;
@@ -172,14 +173,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
172173

173174
unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
174175

175-
std::optional<int64_t> ArrayType::getSizeInBytes() {
176-
auto elementType = llvm::cast<SPIRVType>(getElementType());
177-
std::optional<int64_t> size = elementType.getSizeInBytes();
178-
if (!size)
179-
return std::nullopt;
180-
return (*size + getArrayStride()) * getNumElements();
181-
}
182-
183176
//===----------------------------------------------------------------------===//
184177
// CompositeType
185178
//===----------------------------------------------------------------------===//
@@ -245,28 +238,6 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) {
245238
}
246239
}
247240

248-
std::optional<int64_t> CompositeType::getSizeInBytes() {
249-
if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
250-
return arrayType.getSizeInBytes();
251-
if (auto structType = llvm::dyn_cast<StructType>(*this))
252-
return structType.getSizeInBytes();
253-
if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
254-
std::optional<int64_t> elementSize =
255-
llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
256-
if (!elementSize)
257-
return std::nullopt;
258-
return *elementSize * vectorType.getNumElements();
259-
}
260-
if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
261-
std::optional<int64_t> elementSize =
262-
llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
263-
if (!elementSize)
264-
return std::nullopt;
265-
return *elementSize * tensorArmType.getNumElements();
266-
}
267-
return std::nullopt;
268-
}
269-
270241
//===----------------------------------------------------------------------===//
271242
// CooperativeMatrixType
272243
//===----------------------------------------------------------------------===//
@@ -714,19 +685,6 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
714685
#undef WIDTH_CASE
715686
}
716687

717-
std::optional<int64_t> ScalarType::getSizeInBytes() {
718-
auto bitWidth = getIntOrFloatBitWidth();
719-
// According to the SPIR-V spec:
720-
// "There is no physical size or bit pattern defined for values with boolean
721-
// type. If they are stored (in conjunction with OpVariable), they can only
722-
// be used with logical addressing operations, not physical, and only with
723-
// non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
724-
// Private, Function, Input, and Output."
725-
if (bitWidth == 1)
726-
return std::nullopt;
727-
return bitWidth / 8;
728-
}
729-
730688
//===----------------------------------------------------------------------===//
731689
// SPIRVType
732690
//===----------------------------------------------------------------------===//
@@ -760,11 +718,35 @@ void SPIRVType::getCapabilities(
760718
}
761719

762720
std::optional<int64_t> SPIRVType::getSizeInBytes() {
763-
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
764-
return scalarType.getSizeInBytes();
765-
if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
766-
return compositeType.getSizeInBytes();
767-
return std::nullopt;
721+
return TypeSwitch<SPIRVType, std::optional<int64_t>>(*this)
722+
.Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> {
723+
// According to the SPIR-V spec:
724+
// "There is no physical size or bit pattern defined for values with
725+
// boolean type. If they are stored (in conjunction with OpVariable),
726+
// they can only be used with logical addressing operations, not
727+
// physical, and only with non-externally visible shader Storage
728+
// Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
729+
// Output."
730+
int64_t bitWidth = type.getIntOrFloatBitWidth();
731+
if (bitWidth == 1)
732+
return std::nullopt;
733+
return bitWidth / 8;
734+
})
735+
.Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> {
736+
// Since array type may have an explicit stride declaration (in bytes),
737+
// we also include it in the calculation.
738+
auto elementType = cast<SPIRVType>(type.getElementType());
739+
if (std::optional<int64_t> size = elementType.getSizeInBytes())
740+
return (*size + type.getArrayStride()) * type.getNumElements();
741+
return std::nullopt;
742+
})
743+
.Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
744+
if (std::optional<int64_t> elementSize =
745+
cast<ScalarType>(type.getElementType()).getSizeInBytes())
746+
return *elementSize * type.getNumElements();
747+
return std::nullopt;
748+
})
749+
.Default(std::optional<int64_t>());
768750
}
769751

770752
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)