Skip to content

Commit e01695b

Browse files
committed
Address review comments
1 parent c0fee39 commit e01695b

File tree

2 files changed

+12
-15
lines changed

2 files changed

+12
-15
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -514,10 +514,9 @@ bool ScalarType::isValid(IntegerType type) {
514514

515515
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
516516
std::optional<StorageClass> storage) {
517-
if (llvm::isa<BFloat16Type>(*this)) {
518-
static const Extension exts[] = {Extension::SPV_KHR_bfloat16};
519-
ArrayRef<Extension> ref(exts, std::size(exts));
520-
extensions.push_back(ref);
517+
if (isa<BFloat16Type>(*this)) {
518+
static const Extension ext = Extension::SPV_KHR_bfloat16;
519+
extensions.push_back(ext);
521520
}
522521

523522
// 8- or 16-bit integer/floating-point numbers will require extra extensions
@@ -538,7 +537,7 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
538537
[[fallthrough]];
539538
case StorageClass::Input:
540539
case StorageClass::Output:
541-
if (getIntOrFloatBitWidth() == 16 && !llvm::isa<BFloat16Type>(*this)) {
540+
if (getIntOrFloatBitWidth() == 16 && !isa<BFloat16Type>(*this)) {
542541
static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
543542
ArrayRef<Extension> ref(exts, std::size(exts));
544543
extensions.push_back(ref);
@@ -626,15 +625,12 @@ void ScalarType::getCapabilities(
626625
assert(llvm::isa<FloatType>(*this));
627626
switch (bitwidth) {
628627
case 16: {
629-
if (llvm::isa<BFloat16Type>(*this)) {
630-
static const Capability caps[] = {Capability::BFloat16TypeKHR};
631-
ArrayRef<Capability> ref(caps, std::size(caps));
632-
capabilities.push_back(ref);
633-
628+
if (isa<BFloat16Type>(*this)) {
629+
static const Capability cap = Capability::BFloat16TypeKHR;
630+
capabilities.push_back(cap);
634631
} else {
635-
static const Capability caps[] = {Capability::Float16};
636-
ArrayRef<Capability> ref(caps, std::size(caps));
637-
capabilities.push_back(ref);
632+
static const Capability cap = Capability::Float16;
633+
capabilities.push_back(cap);
638634
}
639635
break;
640636
}

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -868,8 +868,9 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
868868
} break;
869869
case spirv::Opcode::OpTypeFloat: {
870870
if (operands.size() != 2 && operands.size() != 3)
871-
return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter "
872-
"and optional floating point encoding");
871+
return emitError(unknownLoc,
872+
"OpTypeFloat expects either 2 operands (type, bitwidth) "
873+
"or 3 operands (type, bitwidth, encoding)");
873874
uint32_t bitWidth = operands[1];
874875

875876
Type floatTy;

0 commit comments

Comments
 (0)