@@ -696,7 +696,7 @@ static Type *parsePrimitiveType(LLVMContext &Ctx, StringRef Name) {
696696 .Cases (" long" , " unsigned long" , Type::getInt64Ty (Ctx))
697697 .Cases (" long long" , " unsigned long long" , Type::getInt64Ty (Ctx))
698698 .Case (" half" , Type::getHalfTy (Ctx))
699- .Case (" std::bfloat16_t" , Type::getBFloatTy (Ctx))
699+ .Cases (" std::bfloat16_t" , " __bf16 " , Type::getBFloatTy (Ctx))
700700 .Case (" float" , Type::getFloatTy (Ctx))
701701 .Case (" double" , Type::getDoubleTy (Ctx))
702702 .Case (" void" , Type::getInt8Ty (Ctx))
@@ -877,7 +877,8 @@ parseNode(Module *M, const llvm::itanium_demangle::Node *ParamType,
877877 }
878878 } else if (auto *VendorTy = dyn_cast<VendorExtQualType>(ParamType)) {
879879 if (auto *NameTy = dyn_cast<NameType>(VendorTy->getTy ())) {
880- if (NameTy->getName () == " std::bfloat16_t" )
880+ if (NameTy->getName () == " std::bfloat16_t" ||
881+ NameTy->getName () == " __bf16" )
881882 PointeeTy = llvm::Type::getBFloatTy (M->getContext ());
882883 }
883884 // This is a block parameter. Decode the pointee type as if it were a
@@ -937,13 +938,32 @@ bool getParameterTypes(Function *F, SmallVectorImpl<Type *> &ArgTys,
937938 if (HasSret)
938939 ++ArgIter;
939940
941+ // "DF<N>b" mangling for bfloat<N> types (e.g. DF16b for bfloat16) is
942+ // recognized by the demangler only starting from LLVM 20. Replace "DF16b"
943+ // in the parameter section with the vendor-extended-type encoding "u6__bf16",
944+ // which all known demangler versions parse correctly as NameType("__bf16").
945+ std::string PatchedName;
946+ StringRef MangledName (F->getName ());
947+ if (MangledName.contains (" DF16b" )) {
948+ PatchedName = MangledName.str ();
949+ // Skip "_Z<N><name>" to search only in the parameter section.
950+ size_t Start = PatchedName.find_first_not_of (" 0123456789" , 2 );
951+ size_t Len = 0 ;
952+ StringRef (PatchedName).substr (2 , Start - 2 ).getAsInteger (10 , Len);
953+ size_t Pos = Start + Len;
954+ while ((Pos = PatchedName.find (" DF16b" , Pos)) != std::string::npos) {
955+ PatchedName.replace (Pos, 5 , " u6__bf16" );
956+ Pos += 8 ;
957+ }
958+ MangledName = PatchedName;
959+ }
960+
940961 // Demangle the function arguments. If we get an input name of
941962 // "_Z12write_imagei20ocl_image1d_array_woDv2_iiDv4_i", then we expect
942963 // that Demangler.getFunctionParameters will return
943964 // "(ocl_image1d_array_wo, int __vector(2), int, int __vector(4))" (in other
944965 // words, the stuff between the parentheses if you ran C++ filt, including
945966 // the parentheses itself).
946- const StringRef MangledName (F->getName ());
947967 ManglingParser<DefaultAllocator> Demangler (MangledName.begin (),
948968 MangledName.end ());
949969 // We expect to see only function name encodings here. If it's not a function
0 commit comments