@@ -514,10 +514,9 @@ bool ScalarType::isValid(IntegerType type) {
514514
515515void ScalarType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
516516 std::optional<StorageClass> storage) {
517- if (llvm::isa<BFloat16Type>(*this )) {
518- static const Extension exts[] = {Extension::SPV_KHR_bfloat16};
519- ArrayRef<Extension> ref (exts, std::size (exts));
520- extensions.push_back (ref);
517+ if (isa<BFloat16Type>(*this )) {
518+ static const Extension ext = Extension::SPV_KHR_bfloat16;
519+ extensions.push_back (ext);
521520 }
522521
523522 // 8- or 16-bit integer/floating-point numbers will require extra extensions
@@ -538,7 +537,7 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
538537 [[fallthrough]];
539538 case StorageClass::Input:
540539 case StorageClass::Output:
541- if (getIntOrFloatBitWidth () == 16 && !llvm:: isa<BFloat16Type>(*this )) {
540+ if (getIntOrFloatBitWidth () == 16 && !isa<BFloat16Type>(*this )) {
542541 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
543542 ArrayRef<Extension> ref (exts, std::size (exts));
544543 extensions.push_back (ref);
@@ -626,15 +625,12 @@ void ScalarType::getCapabilities(
626625 assert (llvm::isa<FloatType>(*this ));
627626 switch (bitwidth) {
628627 case 16 : {
629- if (llvm::isa<BFloat16Type>(*this )) {
630- static const Capability caps[] = {Capability::BFloat16TypeKHR};
631- ArrayRef<Capability> ref (caps, std::size (caps));
632- capabilities.push_back (ref);
633-
628+ if (isa<BFloat16Type>(*this )) {
629+ static const Capability cap = Capability::BFloat16TypeKHR;
630+ capabilities.push_back (cap);
634631 } else {
635- static const Capability caps[] = {Capability::Float16};
636- ArrayRef<Capability> ref (caps, std::size (caps));
637- capabilities.push_back (ref);
632+ static const Capability cap = Capability::Float16;
633+ capabilities.push_back (cap);
638634 }
639635 break ;
640636 }
0 commit comments