@@ -744,56 +744,26 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
744744 if (auto *VecTy = dyn_cast<VectorType>(SrcTy))
745745 SrcTy = VecTy->getElementType ();
746746 auto IsTargetInt = isa<IntegerType>(TargetTy);
747- auto TargetSigned = DemangledName[8 ] != ' u' ;
748747
749- std::string TargetTyName (
750- DemangledName.substr (strlen (kOCLBuiltinName ::ConvertPrefix)));
751- auto FirstUnderscoreLoc = TargetTyName.find (' _' );
752- if (FirstUnderscoreLoc != std::string::npos)
753- TargetTyName = TargetTyName.substr (0 , FirstUnderscoreLoc);
754-
755- // Validate target type name
756- std::regex Expr (" ([a-z]+)([0-9]*)$" );
748+ // Validate conversion function name and vector size if present
749+ std::regex Expr (
750+ " convert_(float|double|half|u?char|u?short|u?int|u?long)(2|3|4|8|16)*"
751+ " (_sat)*(_rt[ezpn])*$" );
757752 std::smatch DestTyMatch;
758- if (!std::regex_match (TargetTyName, DestTyMatch, Expr))
753+ std::string ConversionFunc (DemangledName.str ());
754+ if (!std::regex_match (ConversionFunc, DestTyMatch, Expr))
759755 return ;
760756
761757 // The first sub_match is the whole string; the next
762- // sub_match is the first parenthesized expression.
763- std::string DestTy = DestTyMatch[1 ].str ();
764-
765- // check it's valid type name
766- static std::unordered_set<std::string> ValidTypes = {
767- " float" , " double" , " half" , " char" , " uchar" , " short" ,
768- " ushort" , " int" , " uint" , " long" , " ulong" };
769-
770- if (ValidTypes.find (DestTy) == ValidTypes.end ())
771- return ;
772-
773- // check that it's allowed vector size
774- std::string VecSize = DestTyMatch[2 ].str ();
775- if (!VecSize.empty ()) {
776- int Size = stoi (VecSize);
777- switch (Size) {
778- case 2 :
779- case 3 :
780- case 4 :
781- case 8 :
782- case 16 :
783- break ;
784- default :
785- return ;
786- }
787- }
788- DemangledName = DemangledName.drop_front (
789- strlen (kOCLBuiltinName ::ConvertPrefix) + TargetTyName.size ());
790- TargetTyName = std::string (" _R" ) + TargetTyName;
758+ // sub_matches are the parenthesized expressions.
759+ enum { TypeIdx = 1 , VecSizeIdx = 2 , SatIdx = 3 , RoundingIdx = 4 };
760+ std::string DestTy = DestTyMatch[TypeIdx].str ();
761+ std::string VecSize = DestTyMatch[VecSizeIdx].str ();
762+ std::string Sat = DestTyMatch[SatIdx].str ();
763+ std::string Rounding = DestTyMatch[RoundingIdx].str ();
791764
792- if (!DemangledName.empty () && !DemangledName.starts_with (" _sat" ) &&
793- !DemangledName.starts_with (" _rt" ))
794- return ;
765+ bool TargetSigned = DestTy[0 ] != ' u' ;
795766
796- std::string Sat = DemangledName.find (" _sat" ) != StringRef::npos ? " _sat" : " " ;
797767 if (isa<IntegerType>(SrcTy)) {
798768 bool Signed = isLastFuncParamSigned (MangledName);
799769 if (IsTargetInt) {
@@ -810,13 +780,13 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
810780 } else
811781 OC = OpFConvert;
812782 }
813- auto Loc = DemangledName.find (" _rt" );
814- std::string Rounding;
815- if (Loc != StringRef::npos && !(isa<IntegerType>(SrcTy) && IsTargetInt)) {
816- Rounding = DemangledName.substr (Loc, 4 ).str ();
817- }
783+
784+ if (!Rounding.empty () && (isa<IntegerType>(SrcTy) && IsTargetInt))
785+ return ;
786+
818787 assert (CI->getCalledFunction () && " Unexpected indirect call" );
819- mutateCallInst (CI, getSPIRVFuncName (OC, TargetTyName + Sat + Rounding));
788+ mutateCallInst (
789+ CI, getSPIRVFuncName (OC, " _R" + DestTy + VecSize + Sat + Rounding));
820790}
821791
822792void OCLToSPIRVBase::visitCallGroupBuiltin (CallInst *CI,
0 commit comments