|
20 | 20 | #include "llvm/Support/ErrorHandling.h"
|
21 | 21 |
|
22 | 22 | #include <cstdint>
|
| 23 | +#include <optional> |
23 | 24 |
|
24 | 25 | using namespace mlir;
|
25 | 26 | using namespace mlir::spirv;
|
@@ -172,14 +173,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
|
172 | 173 |
|
173 | 174 | unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
|
174 | 175 |
|
175 |
| -std::optional<int64_t> ArrayType::getSizeInBytes() { |
176 |
| - auto elementType = llvm::cast<SPIRVType>(getElementType()); |
177 |
| - std::optional<int64_t> size = elementType.getSizeInBytes(); |
178 |
| - if (!size) |
179 |
| - return std::nullopt; |
180 |
| - return (*size + getArrayStride()) * getNumElements(); |
181 |
| -} |
182 |
| - |
183 | 176 | //===----------------------------------------------------------------------===//
|
184 | 177 | // CompositeType
|
185 | 178 | //===----------------------------------------------------------------------===//
|
@@ -245,28 +238,6 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) {
|
245 | 238 | }
|
246 | 239 | }
|
247 | 240 |
|
248 |
| -std::optional<int64_t> CompositeType::getSizeInBytes() { |
249 |
| - if (auto arrayType = llvm::dyn_cast<ArrayType>(*this)) |
250 |
| - return arrayType.getSizeInBytes(); |
251 |
| - if (auto structType = llvm::dyn_cast<StructType>(*this)) |
252 |
| - return structType.getSizeInBytes(); |
253 |
| - if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) { |
254 |
| - std::optional<int64_t> elementSize = |
255 |
| - llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes(); |
256 |
| - if (!elementSize) |
257 |
| - return std::nullopt; |
258 |
| - return *elementSize * vectorType.getNumElements(); |
259 |
| - } |
260 |
| - if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) { |
261 |
| - std::optional<int64_t> elementSize = |
262 |
| - llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes(); |
263 |
| - if (!elementSize) |
264 |
| - return std::nullopt; |
265 |
| - return *elementSize * tensorArmType.getNumElements(); |
266 |
| - } |
267 |
| - return std::nullopt; |
268 |
| -} |
269 |
| - |
270 | 241 | //===----------------------------------------------------------------------===//
|
271 | 242 | // CooperativeMatrixType
|
272 | 243 | //===----------------------------------------------------------------------===//
|
@@ -714,19 +685,6 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
|
714 | 685 | #undef WIDTH_CASE
|
715 | 686 | }
|
716 | 687 |
|
717 |
| -std::optional<int64_t> ScalarType::getSizeInBytes() { |
718 |
| - auto bitWidth = getIntOrFloatBitWidth(); |
719 |
| - // According to the SPIR-V spec: |
720 |
| - // "There is no physical size or bit pattern defined for values with boolean |
721 |
| - // type. If they are stored (in conjunction with OpVariable), they can only |
722 |
| - // be used with logical addressing operations, not physical, and only with |
723 |
| - // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, |
724 |
| - // Private, Function, Input, and Output." |
725 |
| - if (bitWidth == 1) |
726 |
| - return std::nullopt; |
727 |
| - return bitWidth / 8; |
728 |
| -} |
729 |
| - |
730 | 688 | //===----------------------------------------------------------------------===//
|
731 | 689 | // SPIRVType
|
732 | 690 | //===----------------------------------------------------------------------===//
|
@@ -760,11 +718,35 @@ void SPIRVType::getCapabilities(
|
760 | 718 | }
|
761 | 719 |
|
762 | 720 | std::optional<int64_t> SPIRVType::getSizeInBytes() {
|
763 |
| - if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) |
764 |
| - return scalarType.getSizeInBytes(); |
765 |
| - if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) |
766 |
| - return compositeType.getSizeInBytes(); |
767 |
| - return std::nullopt; |
| 721 | + return TypeSwitch<SPIRVType, std::optional<int64_t>>(*this) |
| 722 | + .Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> { |
| 723 | + // According to the SPIR-V spec: |
| 724 | + // "There is no physical size or bit pattern defined for values with |
| 725 | + // boolean type. If they are stored (in conjunction with OpVariable), |
| 726 | + // they can only be used with logical addressing operations, not |
| 727 | + // physical, and only with non-externally visible shader Storage |
| 728 | + // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and |
| 729 | + // Output." |
| 730 | + int64_t bitWidth = type.getIntOrFloatBitWidth(); |
| 731 | + if (bitWidth == 1) |
| 732 | + return std::nullopt; |
| 733 | + return bitWidth / 8; |
| 734 | + }) |
| 735 | + .Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> { |
| 736 | + // Since array type may have an explicit stride declaration (in bytes), |
| 737 | + // we also include it in the calculation. |
| 738 | + auto elementType = cast<SPIRVType>(type.getElementType()); |
| 739 | + if (std::optional<int64_t> size = elementType.getSizeInBytes()) |
| 740 | + return (*size + type.getArrayStride()) * type.getNumElements(); |
| 741 | + return std::nullopt; |
| 742 | + }) |
| 743 | + .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> { |
| 744 | + if (std::optional<int64_t> elementSize = |
| 745 | + cast<ScalarType>(type.getElementType()).getSizeInBytes()) |
| 746 | + return *elementSize * type.getNumElements(); |
| 747 | + return std::nullopt; |
| 748 | + }) |
| 749 | + .Default(std::optional<int64_t>()); |
768 | 750 | }
|
769 | 751 |
|
770 | 752 | //===----------------------------------------------------------------------===//
|
|
0 commit comments