@@ -554,74 +554,152 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
554554 return ArrayAttr::get (mlirModule->getContext (), cgProfile);
555555}
556556
557- static Attribute
558- convertProfileSummaryModuleFlagValue (ModuleOp mlirModule,
559- const llvm::Module *llvmModule,
560- llvm::MDTuple *mdTuple) {
561- unsigned profileNumEntries = mdTuple->getNumOperands ();
562- if (profileNumEntries < 8 ) {
557+ static llvm::MDTuple *getTwoElementMDTuple (ModuleOp mlirModule,
558+ const llvm::Module *llvmModule,
559+ const llvm::MDOperand &md) {
560+ auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
561+ if (!tupleEntry || tupleEntry->getNumOperands () != 2 )
563562 emitWarning (mlirModule.getLoc ())
564- << " expected at 8 entries in 'ProfileSummary': "
565- << diagMD (mdTuple, llvmModule);
563+ << " expected 2-element tuple metadata: " << diagMD (md, llvmModule);
564+ return tupleEntry;
565+ }
566+
567+ static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple (
568+ ModuleOp mlirModule, const llvm::Module *llvmModule,
569+ const llvm::MDOperand &md, StringRef matchKey, bool optional = false ) {
570+ llvm::MDTuple *tupleEntry = getTwoElementMDTuple (mlirModule, llvmModule, md);
571+ if (!tupleEntry)
572+ return nullptr ;
573+ llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (0 ));
574+ if (!keyMD || keyMD->getString () != matchKey) {
575+ if (!optional)
576+ emitWarning (mlirModule.getLoc ())
577+ << " expected '" << matchKey << " ' key, but found: "
578+ << diagMD (tupleEntry->getOperand (0 ), llvmModule);
566579 return nullptr ;
567580 }
568581
569- unsigned summayIdx = 0 ;
582+ return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand (1 ));
583+ }
570584
571- auto getMDTuple = [&](const llvm::MDOperand &md) {
572- auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
573- if (!tupleEntry || tupleEntry->getNumOperands () != 2 )
574- emitWarning (mlirModule.getLoc ())
575- << " expected 2-element tuple metadata: " << diagMD (md, llvmModule);
576- return tupleEntry;
577- };
585+ static bool convertInt64FromKeyValueTuple (ModuleOp mlirModule,
586+ const llvm::Module *llvmModule,
587+ const llvm::MDOperand &md,
588+ StringRef matchKey, uint64_t &val) {
589+ auto *valMD =
590+ getConstantMDFromKeyValueTuple (mlirModule, llvmModule, md, matchKey);
591+ if (!valMD)
592+ return false ;
578593
579- auto getFormat = [&](const llvm::MDOperand &formatMD)
580- -> std::optional<ProfileSummaryFormatKind> {
581- auto *tupleEntry = getMDTuple (formatMD);
582- if (!tupleEntry)
583- return std::nullopt ;
594+ if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue ())) {
595+ val = cstInt->getZExtValue ();
596+ return true ;
597+ }
598+
599+ emitWarning (mlirModule.getLoc ())
600+ << " expected integer metadata value for key '" << matchKey
601+ << " ': " << diagMD (md, llvmModule);
602+ return false ;
603+ }
604+
605+ static std::optional<ProfileSummaryFormatKind>
606+ convertProfileSummaryFormat (ModuleOp mlirModule, const llvm::Module *llvmModule,
607+ const llvm::MDOperand &formatMD) {
608+ auto *tupleEntry = getTwoElementMDTuple (mlirModule, llvmModule, formatMD);
609+ if (!tupleEntry)
610+ return std::nullopt ;
611+
612+ llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (0 ));
613+ if (!keyMD || keyMD->getString () != " ProfileFormat" ) {
614+ emitWarning (mlirModule.getLoc ())
615+ << " expected 'ProfileFormat' key: "
616+ << diagMD (tupleEntry->getOperand (0 ), llvmModule);
617+ return std::nullopt ;
618+ }
619+
620+ llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (1 ));
621+ std::optional<ProfileSummaryFormatKind> fmtKind =
622+ symbolizeProfileSummaryFormatKind (valMD->getString ());
623+ if (!fmtKind) {
624+ emitWarning (mlirModule.getLoc ())
625+ << " expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
626+ " but found: "
627+ << diagMD (valMD, llvmModule);
628+ return std::nullopt ;
629+ }
630+
631+ return fmtKind;
632+ }
633+
634+ static bool convertProfileSummaryDetailed (
635+ ModuleOp mlirModule, const llvm::Module *llvmModule,
636+ const llvm::MDOperand &summaryMD,
637+ SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr> &detailedSummary) {
638+ auto *tupleEntry = getTwoElementMDTuple (mlirModule, llvmModule, summaryMD);
639+ if (!tupleEntry)
640+ return false ;
641+
642+ llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (0 ));
643+ if (!keyMD || keyMD->getString () != " DetailedSummary" ) {
644+ emitWarning (mlirModule.getLoc ())
645+ << " expected 'DetailedSummary' key: "
646+ << diagMD (tupleEntry->getOperand (0 ), llvmModule);
647+ return false ;
648+ }
584649
585- llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (0 ));
586- if (!keyMD || keyMD->getString () != " ProfileFormat" ) {
650+ llvm::MDTuple *entriesMD = dyn_cast<llvm::MDTuple>(tupleEntry->getOperand (1 ));
651+ if (!entriesMD) {
652+ emitWarning (mlirModule.getLoc ())
653+ << " expected tuple value for 'DetailedSummary' key: "
654+ << diagMD (tupleEntry->getOperand (1 ), llvmModule);
655+ return false ;
656+ }
657+
658+ for (auto &&entry : entriesMD->operands ()) {
659+ llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
660+ if (!entryMD || entryMD->getNumOperands () != 3 ) {
587661 emitWarning (mlirModule.getLoc ())
588- << " expected 'ProfileFormat' key : "
589- << diagMD (tupleEntry-> getOperand ( 0 ) , llvmModule);
590- return std:: nullopt ;
662+ << " 'DetailedSummary' entry expects 3 operands : "
663+ << diagMD (entry , llvmModule);
664+ return false ;
591665 }
592-
593- llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (1 ));
594- std::optional<ProfileSummaryFormatKind> fmtKind =
595- symbolizeProfileSummaryFormatKind (valMD->getString ());
596- if (!fmtKind) {
666+ llvm::ConstantAsMetadata *op0 =
667+ dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand (0 ));
668+ llvm::ConstantAsMetadata *op1 =
669+ dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand (1 ));
670+ llvm::ConstantAsMetadata *op2 =
671+ dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand (2 ));
672+
673+ if (!op0 || !op1 || !op2) {
597674 emitWarning (mlirModule.getLoc ())
598- << " expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
599- " but found: "
600- << diagMD (valMD, llvmModule);
601- return std::nullopt ;
675+ << " expected only integer entries in 'DetailedSummary': "
676+ << diagMD (entry, llvmModule);
677+ return false ;
602678 }
603679
604- return fmtKind;
605- };
606-
607- auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey,
608- bool optional =
609- false ) -> llvm::ConstantAsMetadata * {
610- llvm::MDTuple *tupleEntry = getMDTuple (md);
611- if (!tupleEntry)
612- return nullptr ;
613- llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (0 ));
614- if (!keyMD || keyMD->getString () != matchKey) {
615- if (!optional)
616- emitWarning (mlirModule.getLoc ())
617- << " expected '" << matchKey << " ' key, but found: "
618- << diagMD (tupleEntry->getOperand (0 ), llvmModule);
619- return nullptr ;
620- }
680+ auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get (
681+ mlirModule->getContext (),
682+ cast<llvm::ConstantInt>(op0->getValue ())->getZExtValue (),
683+ cast<llvm::ConstantInt>(op1->getValue ())->getZExtValue (),
684+ cast<llvm::ConstantInt>(op2->getValue ())->getZExtValue ());
685+ detailedSummary.push_back (detaildSummaryEntry);
686+ }
687+ return true ;
688+ }
621689
622- return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand (1 ));
623- };
690+ static Attribute
691+ convertProfileSummaryModuleFlagValue (ModuleOp mlirModule,
692+ const llvm::Module *llvmModule,
693+ llvm::MDTuple *mdTuple) {
694+ unsigned profileNumEntries = mdTuple->getNumOperands ();
695+ if (profileNumEntries < 8 ) {
696+ emitWarning (mlirModule.getLoc ())
697+ << " expected at 8 entries in 'ProfileSummary': "
698+ << diagMD (mdTuple, llvmModule);
699+ return nullptr ;
700+ }
624701
702+ unsigned summayIdx = 0 ;
625703 auto checkOptionalPosition = [&](const llvm::MDOperand &md,
626704 StringRef matchKey) -> LogicalResult {
627705 // Make sure we won't step over the bound of the array of summary entries.
@@ -637,39 +715,25 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
637715 return success ();
638716 };
639717
640- auto getInt64Value = [&](const llvm::MDOperand &md, StringRef matchKey,
641- uint64_t &val) {
642- auto *valMD = getConstantMD (md, matchKey);
643- if (!valMD)
644- return false ;
645-
646- if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue ())) {
647- val = cstInt->getZExtValue ();
648- return true ;
649- }
650-
651- emitWarning (mlirModule.getLoc ())
652- << " expected integer metadata value for key '" << matchKey
653- << " ': " << diagMD (md, llvmModule);
654- return false ;
655- };
656-
657718 auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey,
658719 std::optional<uint64_t > &val) -> LogicalResult {
659- if (!getConstantMD (md, matchKey, /* optional=*/ true ))
720+ if (!getConstantMDFromKeyValueTuple (mlirModule, llvmModule, md, matchKey,
721+ /* optional=*/ true ))
660722 return success ();
661723 if (checkOptionalPosition (md, matchKey).failed ())
662724 return failure ();
663725 uint64_t tmpVal = 0 ;
664- if (!getInt64Value (md, matchKey, tmpVal))
726+ if (!convertInt64FromKeyValueTuple (mlirModule, llvmModule, md, matchKey,
727+ tmpVal))
665728 return failure ();
666729 val = tmpVal;
667730 return success ();
668731 };
669732
670733 auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey,
671734 FloatAttr &attr) -> LogicalResult {
672- auto *valMD = getConstantMD (md, matchKey, /* optional=*/ true );
735+ auto *valMD = getConstantMDFromKeyValueTuple (mlirModule, llvmModule, md,
736+ matchKey, /* optional=*/ true );
673737 if (!valMD)
674738 return success ();
675739 if (auto *cstFP = dyn_cast<llvm::ConstantFP>(valMD->getValue ())) {
@@ -685,87 +749,39 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
685749 return failure ();
686750 };
687751
688- auto getSummary = [&](const llvm::MDOperand &summaryMD,
689- SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr>
690- &detailedSummary) {
691- auto *tupleEntry = getMDTuple (summaryMD);
692- if (!tupleEntry)
693- return false ;
694-
695- llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand (0 ));
696- if (!keyMD || keyMD->getString () != " DetailedSummary" ) {
697- emitWarning (mlirModule.getLoc ())
698- << " expected 'DetailedSummary' key: "
699- << diagMD (tupleEntry->getOperand (0 ), llvmModule);
700- return false ;
701- }
702-
703- llvm::MDTuple *entriesMD =
704- dyn_cast<llvm::MDTuple>(tupleEntry->getOperand (1 ));
705- if (!entriesMD) {
706- emitWarning (mlirModule.getLoc ())
707- << " expected tuple value for 'DetailedSummary' key: "
708- << diagMD (tupleEntry->getOperand (1 ), llvmModule);
709- return false ;
710- }
711-
712- for (auto &&entry : entriesMD->operands ()) {
713- llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
714- if (!entryMD || entryMD->getNumOperands () != 3 ) {
715- emitWarning (mlirModule.getLoc ())
716- << " 'DetailedSummary' entry expects 3 operands: "
717- << diagMD (entry, llvmModule);
718- return false ;
719- }
720- llvm::ConstantAsMetadata *op0 =
721- dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand (0 ));
722- llvm::ConstantAsMetadata *op1 =
723- dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand (1 ));
724- llvm::ConstantAsMetadata *op2 =
725- dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand (2 ));
726-
727- if (!op0 || !op1 || !op2) {
728- emitWarning (mlirModule.getLoc ())
729- << " expected only integer entries in 'DetailedSummary': "
730- << diagMD (entry, llvmModule);
731- return false ;
732- }
733-
734- auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get (
735- mlirModule->getContext (),
736- cast<llvm::ConstantInt>(op0->getValue ())->getZExtValue (),
737- cast<llvm::ConstantInt>(op1->getValue ())->getZExtValue (),
738- cast<llvm::ConstantInt>(op2->getValue ())->getZExtValue ());
739- detailedSummary.push_back (detaildSummaryEntry);
740- }
741- return true ;
742- };
743-
744752 // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
745753 // a fixed order: format, total count, etc.
746754 SmallVector<Attribute> profileSummary;
747- std::optional<ProfileSummaryFormatKind> format =
748- getFormat ( mdTuple->getOperand (summayIdx++));
755+ std::optional<ProfileSummaryFormatKind> format = convertProfileSummaryFormat (
756+ mlirModule, llvmModule, mdTuple->getOperand (summayIdx++));
749757 if (!format.has_value ())
750758 return nullptr ;
751759
752760 uint64_t totalCount = 0 , maxCount = 0 , maxInternalCount = 0 ,
753761 maxFunctionCount = 0 , numCounts = 0 , numFunctions = 0 ;
754- if (!getInt64Value (mdTuple->getOperand (summayIdx++), " TotalCount" ,
755- totalCount))
762+ if (!convertInt64FromKeyValueTuple (mlirModule, llvmModule,
763+ mdTuple->getOperand (summayIdx++),
764+ " TotalCount" , totalCount))
756765 return nullptr ;
757- if (!getInt64Value (mdTuple->getOperand (summayIdx++), " MaxCount" , maxCount))
766+ if (!convertInt64FromKeyValueTuple (mlirModule, llvmModule,
767+ mdTuple->getOperand (summayIdx++),
768+ " MaxCount" , maxCount))
758769 return nullptr ;
759- if (!getInt64Value (mdTuple->getOperand (summayIdx++), " MaxInternalCount" ,
760- maxInternalCount))
770+ if (!convertInt64FromKeyValueTuple (mlirModule, llvmModule,
771+ mdTuple->getOperand (summayIdx++),
772+ " MaxInternalCount" , maxInternalCount))
761773 return nullptr ;
762- if (!getInt64Value (mdTuple->getOperand (summayIdx++), " MaxFunctionCount" ,
763- maxFunctionCount))
774+ if (!convertInt64FromKeyValueTuple (mlirModule, llvmModule,
775+ mdTuple->getOperand (summayIdx++),
776+ " MaxFunctionCount" , maxFunctionCount))
764777 return nullptr ;
765- if (!getInt64Value (mdTuple->getOperand (summayIdx++), " NumCounts" , numCounts))
778+ if (!convertInt64FromKeyValueTuple (mlirModule, llvmModule,
779+ mdTuple->getOperand (summayIdx++),
780+ " NumCounts" , numCounts))
766781 return nullptr ;
767- if (!getInt64Value (mdTuple->getOperand (summayIdx++), " NumFunctions" ,
768- numFunctions))
782+ if (!convertInt64FromKeyValueTuple (mlirModule, llvmModule,
783+ mdTuple->getOperand (summayIdx++),
784+ " NumFunctions" , numFunctions))
769785 return nullptr ;
770786
771787 // Handle optional keys.
@@ -786,15 +802,16 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
786802 summayIdx++;
787803
788804 // Handle detailed summary.
789- SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailedSummary;
790- if (!getSummary (mdTuple->getOperand (summayIdx), detailedSummary))
805+ SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailed;
806+ if (!convertProfileSummaryDetailed (mlirModule, llvmModule,
807+ mdTuple->getOperand (summayIdx), detailed))
791808 return nullptr ;
792809
793810 // Build the final profile summary attribute.
794811 return ModuleFlagProfileSummaryAttr::get (
795812 mlirModule->getContext (), *format, totalCount, maxCount, maxInternalCount,
796813 maxFunctionCount, numCounts, numFunctions, isPartialProfile,
797- partialProfileRatio ? partialProfileRatio : nullptr , detailedSummary );
814+ partialProfileRatio ? partialProfileRatio : nullptr , detailed );
798815}
799816
800817// / Invoke specific handlers for each known module flag value, returns nullptr
0 commit comments