@@ -576,34 +576,32 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
576576 return tupleEntry;
577577 };
578578
579- auto getFormat = [&](const llvm::MDOperand &formatMD) -> StringAttr {
579+ auto getFormat = [&](const llvm::MDOperand &formatMD)
580+ -> std::optional<ProfileSummaryFormatKind> {
580581 auto *tupleEntry = getMDTuple (formatMD);
581582 if (!tupleEntry)
582- return nullptr ;
583+ return std:: nullopt ;
583584
584585 llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (0 ));
585586 if (!keyMD || keyMD->getString () != " ProfileFormat" ) {
586587 emitWarning (mlirModule.getLoc ())
587588 << " expected 'ProfileFormat' key: "
588589 << diagMD (tupleEntry->getOperand (0 ), llvmModule);
589- return nullptr ;
590+ return std:: nullopt ;
590591 }
591592
592593 llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (1 ));
593- auto formatAttr = llvm::StringSwitch<std::string>(valMD->getString ())
594- .Case (" SampleProfile" , " SampleProfile" )
595- .Case (" InstrProf" , " InstrProf" )
596- .Case (" CSInstrProf" , " CSInstrProf" )
597- .Default (" " );
598- if (formatAttr.empty ()) {
594+ std::optional<ProfileSummaryFormatKind> fmtKind =
595+ symbolizeProfileSummaryFormatKind (valMD->getString ());
596+ if (!fmtKind) {
599597 emitWarning (mlirModule.getLoc ())
600598 << " expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
601599 " but found: "
602600 << diagMD (valMD, llvmModule);
603- return nullptr ;
601+ return std:: nullopt ;
604602 }
605603
606- return StringAttr::get (mlirModule-> getContext (), formatAttr) ;
604+ return fmtKind ;
607605 };
608606
609607 auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey,
@@ -746,8 +744,9 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
746744 // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
747745 // a fixed order: format, total count, etc.
748746 SmallVector<Attribute> profileSummary;
749- StringAttr format = getFormat (mdTuple->getOperand (summayIdx++));
750- if (!format)
747+ std::optional<ProfileSummaryFormatKind> format =
748+ getFormat (mdTuple->getOperand (summayIdx++));
749+ if (!format.has_value ())
751750 return nullptr ;
752751
753752 uint64_t totalCount = 0 , maxCount = 0 , maxInternalCount = 0 ,
@@ -793,7 +792,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
793792
794793 // Build the final profile summary attribute.
795794 return ModuleFlagProfileSummaryAttr::get (
796- mlirModule->getContext (), format, totalCount, maxCount, maxInternalCount,
795+ mlirModule->getContext (), * format, totalCount, maxCount, maxInternalCount,
797796 maxFunctionCount, numCounts, numFunctions, isPartialProfile,
798797 partialProfileRatio ? partialProfileRatio : nullptr , detailedSummary);
799798}
0 commit comments