Skip to content

Commit 0eff0a4

Browse files
committed
move some lambdas to static local functions
1 parent d81fcce commit 0eff0a4

File tree

1 file changed

+162
-145
lines changed

1 file changed

+162
-145
lines changed

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 162 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -554,74 +554,152 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
554554
return ArrayAttr::get(mlirModule->getContext(), cgProfile);
555555
}
556556

557-
static Attribute
558-
convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
559-
const llvm::Module *llvmModule,
560-
llvm::MDTuple *mdTuple) {
561-
unsigned profileNumEntries = mdTuple->getNumOperands();
562-
if (profileNumEntries < 8) {
557+
static llvm::MDTuple *getTwoElementMDTuple(ModuleOp mlirModule,
558+
const llvm::Module *llvmModule,
559+
const llvm::MDOperand &md) {
560+
auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
561+
if (!tupleEntry || tupleEntry->getNumOperands() != 2)
563562
emitWarning(mlirModule.getLoc())
564-
<< "expected at 8 entries in 'ProfileSummary': "
565-
<< diagMD(mdTuple, llvmModule);
563+
<< "expected 2-element tuple metadata: " << diagMD(md, llvmModule);
564+
return tupleEntry;
565+
}
566+
567+
static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
568+
ModuleOp mlirModule, const llvm::Module *llvmModule,
569+
const llvm::MDOperand &md, StringRef matchKey, bool optional = false) {
570+
llvm::MDTuple *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, md);
571+
if (!tupleEntry)
572+
return nullptr;
573+
llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
574+
if (!keyMD || keyMD->getString() != matchKey) {
575+
if (!optional)
576+
emitWarning(mlirModule.getLoc())
577+
<< "expected '" << matchKey << "' key, but found: "
578+
<< diagMD(tupleEntry->getOperand(0), llvmModule);
566579
return nullptr;
567580
}
568581

569-
unsigned summayIdx = 0;
582+
return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
583+
}
570584

571-
auto getMDTuple = [&](const llvm::MDOperand &md) {
572-
auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
573-
if (!tupleEntry || tupleEntry->getNumOperands() != 2)
574-
emitWarning(mlirModule.getLoc())
575-
<< "expected 2-element tuple metadata: " << diagMD(md, llvmModule);
576-
return tupleEntry;
577-
};
585+
static bool convertInt64FromKeyValueTuple(ModuleOp mlirModule,
586+
const llvm::Module *llvmModule,
587+
const llvm::MDOperand &md,
588+
StringRef matchKey, uint64_t &val) {
589+
auto *valMD =
590+
getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey);
591+
if (!valMD)
592+
return false;
578593

579-
auto getFormat = [&](const llvm::MDOperand &formatMD)
580-
-> std::optional<ProfileSummaryFormatKind> {
581-
auto *tupleEntry = getMDTuple(formatMD);
582-
if (!tupleEntry)
583-
return std::nullopt;
594+
if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue())) {
595+
val = cstInt->getZExtValue();
596+
return true;
597+
}
598+
599+
emitWarning(mlirModule.getLoc())
600+
<< "expected integer metadata value for key '" << matchKey
601+
<< "': " << diagMD(md, llvmModule);
602+
return false;
603+
}
604+
605+
static std::optional<ProfileSummaryFormatKind>
606+
convertProfileSummaryFormat(ModuleOp mlirModule, const llvm::Module *llvmModule,
607+
const llvm::MDOperand &formatMD) {
608+
auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, formatMD);
609+
if (!tupleEntry)
610+
return std::nullopt;
611+
612+
llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
613+
if (!keyMD || keyMD->getString() != "ProfileFormat") {
614+
emitWarning(mlirModule.getLoc())
615+
<< "expected 'ProfileFormat' key: "
616+
<< diagMD(tupleEntry->getOperand(0), llvmModule);
617+
return std::nullopt;
618+
}
619+
620+
llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(1));
621+
std::optional<ProfileSummaryFormatKind> fmtKind =
622+
symbolizeProfileSummaryFormatKind(valMD->getString());
623+
if (!fmtKind) {
624+
emitWarning(mlirModule.getLoc())
625+
<< "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
626+
"but found: "
627+
<< diagMD(valMD, llvmModule);
628+
return std::nullopt;
629+
}
630+
631+
return fmtKind;
632+
}
633+
634+
static bool convertProfileSummaryDetailed(
635+
ModuleOp mlirModule, const llvm::Module *llvmModule,
636+
const llvm::MDOperand &summaryMD,
637+
SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr> &detailedSummary) {
638+
auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, summaryMD);
639+
if (!tupleEntry)
640+
return false;
641+
642+
llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
643+
if (!keyMD || keyMD->getString() != "DetailedSummary") {
644+
emitWarning(mlirModule.getLoc())
645+
<< "expected 'DetailedSummary' key: "
646+
<< diagMD(tupleEntry->getOperand(0), llvmModule);
647+
return false;
648+
}
584649

585-
llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
586-
if (!keyMD || keyMD->getString() != "ProfileFormat") {
650+
llvm::MDTuple *entriesMD = dyn_cast<llvm::MDTuple>(tupleEntry->getOperand(1));
651+
if (!entriesMD) {
652+
emitWarning(mlirModule.getLoc())
653+
<< "expected tuple value for 'DetailedSummary' key: "
654+
<< diagMD(tupleEntry->getOperand(1), llvmModule);
655+
return false;
656+
}
657+
658+
for (auto &&entry : entriesMD->operands()) {
659+
llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
660+
if (!entryMD || entryMD->getNumOperands() != 3) {
587661
emitWarning(mlirModule.getLoc())
588-
<< "expected 'ProfileFormat' key: "
589-
<< diagMD(tupleEntry->getOperand(0), llvmModule);
590-
return std::nullopt;
662+
<< "'DetailedSummary' entry expects 3 operands: "
663+
<< diagMD(entry, llvmModule);
664+
return false;
591665
}
592-
593-
llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(1));
594-
std::optional<ProfileSummaryFormatKind> fmtKind =
595-
symbolizeProfileSummaryFormatKind(valMD->getString());
596-
if (!fmtKind) {
666+
llvm::ConstantAsMetadata *op0 =
667+
dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(0));
668+
llvm::ConstantAsMetadata *op1 =
669+
dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(1));
670+
llvm::ConstantAsMetadata *op2 =
671+
dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(2));
672+
673+
if (!op0 || !op1 || !op2) {
597674
emitWarning(mlirModule.getLoc())
598-
<< "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
599-
"but found: "
600-
<< diagMD(valMD, llvmModule);
601-
return std::nullopt;
675+
<< "expected only integer entries in 'DetailedSummary': "
676+
<< diagMD(entry, llvmModule);
677+
return false;
602678
}
603679

604-
return fmtKind;
605-
};
606-
607-
auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey,
608-
bool optional =
609-
false) -> llvm::ConstantAsMetadata * {
610-
llvm::MDTuple *tupleEntry = getMDTuple(md);
611-
if (!tupleEntry)
612-
return nullptr;
613-
llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
614-
if (!keyMD || keyMD->getString() != matchKey) {
615-
if (!optional)
616-
emitWarning(mlirModule.getLoc())
617-
<< "expected '" << matchKey << "' key, but found: "
618-
<< diagMD(tupleEntry->getOperand(0), llvmModule);
619-
return nullptr;
620-
}
680+
auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get(
681+
mlirModule->getContext(),
682+
cast<llvm::ConstantInt>(op0->getValue())->getZExtValue(),
683+
cast<llvm::ConstantInt>(op1->getValue())->getZExtValue(),
684+
cast<llvm::ConstantInt>(op2->getValue())->getZExtValue());
685+
detailedSummary.push_back(detaildSummaryEntry);
686+
}
687+
return true;
688+
}
621689

622-
return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
623-
};
690+
static Attribute
691+
convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
692+
const llvm::Module *llvmModule,
693+
llvm::MDTuple *mdTuple) {
694+
unsigned profileNumEntries = mdTuple->getNumOperands();
695+
if (profileNumEntries < 8) {
696+
emitWarning(mlirModule.getLoc())
697+
<< "expected at 8 entries in 'ProfileSummary': "
698+
<< diagMD(mdTuple, llvmModule);
699+
return nullptr;
700+
}
624701

702+
unsigned summayIdx = 0;
625703
auto checkOptionalPosition = [&](const llvm::MDOperand &md,
626704
StringRef matchKey) -> LogicalResult {
627705
// Make sure we won't step over the bound of the array of summary entries.
@@ -637,39 +715,25 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
637715
return success();
638716
};
639717

640-
auto getInt64Value = [&](const llvm::MDOperand &md, StringRef matchKey,
641-
uint64_t &val) {
642-
auto *valMD = getConstantMD(md, matchKey);
643-
if (!valMD)
644-
return false;
645-
646-
if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue())) {
647-
val = cstInt->getZExtValue();
648-
return true;
649-
}
650-
651-
emitWarning(mlirModule.getLoc())
652-
<< "expected integer metadata value for key '" << matchKey
653-
<< "': " << diagMD(md, llvmModule);
654-
return false;
655-
};
656-
657718
auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey,
658719
std::optional<uint64_t> &val) -> LogicalResult {
659-
if (!getConstantMD(md, matchKey, /*optional=*/true))
720+
if (!getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
721+
/*optional=*/true))
660722
return success();
661723
if (checkOptionalPosition(md, matchKey).failed())
662724
return failure();
663725
uint64_t tmpVal = 0;
664-
if (!getInt64Value(md, matchKey, tmpVal))
726+
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
727+
tmpVal))
665728
return failure();
666729
val = tmpVal;
667730
return success();
668731
};
669732

670733
auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey,
671734
FloatAttr &attr) -> LogicalResult {
672-
auto *valMD = getConstantMD(md, matchKey, /*optional=*/true);
735+
auto *valMD = getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md,
736+
matchKey, /*optional=*/true);
673737
if (!valMD)
674738
return success();
675739
if (auto *cstFP = dyn_cast<llvm::ConstantFP>(valMD->getValue())) {
@@ -685,87 +749,39 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
685749
return failure();
686750
};
687751

688-
auto getSummary = [&](const llvm::MDOperand &summaryMD,
689-
SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr>
690-
&detailedSummary) {
691-
auto *tupleEntry = getMDTuple(summaryMD);
692-
if (!tupleEntry)
693-
return false;
694-
695-
llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
696-
if (!keyMD || keyMD->getString() != "DetailedSummary") {
697-
emitWarning(mlirModule.getLoc())
698-
<< "expected 'DetailedSummary' key: "
699-
<< diagMD(tupleEntry->getOperand(0), llvmModule);
700-
return false;
701-
}
702-
703-
llvm::MDTuple *entriesMD =
704-
dyn_cast<llvm::MDTuple>(tupleEntry->getOperand(1));
705-
if (!entriesMD) {
706-
emitWarning(mlirModule.getLoc())
707-
<< "expected tuple value for 'DetailedSummary' key: "
708-
<< diagMD(tupleEntry->getOperand(1), llvmModule);
709-
return false;
710-
}
711-
712-
for (auto &&entry : entriesMD->operands()) {
713-
llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
714-
if (!entryMD || entryMD->getNumOperands() != 3) {
715-
emitWarning(mlirModule.getLoc())
716-
<< "'DetailedSummary' entry expects 3 operands: "
717-
<< diagMD(entry, llvmModule);
718-
return false;
719-
}
720-
llvm::ConstantAsMetadata *op0 =
721-
dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(0));
722-
llvm::ConstantAsMetadata *op1 =
723-
dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(1));
724-
llvm::ConstantAsMetadata *op2 =
725-
dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(2));
726-
727-
if (!op0 || !op1 || !op2) {
728-
emitWarning(mlirModule.getLoc())
729-
<< "expected only integer entries in 'DetailedSummary': "
730-
<< diagMD(entry, llvmModule);
731-
return false;
732-
}
733-
734-
auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get(
735-
mlirModule->getContext(),
736-
cast<llvm::ConstantInt>(op0->getValue())->getZExtValue(),
737-
cast<llvm::ConstantInt>(op1->getValue())->getZExtValue(),
738-
cast<llvm::ConstantInt>(op2->getValue())->getZExtValue());
739-
detailedSummary.push_back(detaildSummaryEntry);
740-
}
741-
return true;
742-
};
743-
744752
// Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
745753
// a fixed order: format, total count, etc.
746754
SmallVector<Attribute> profileSummary;
747-
std::optional<ProfileSummaryFormatKind> format =
748-
getFormat(mdTuple->getOperand(summayIdx++));
755+
std::optional<ProfileSummaryFormatKind> format = convertProfileSummaryFormat(
756+
mlirModule, llvmModule, mdTuple->getOperand(summayIdx++));
749757
if (!format.has_value())
750758
return nullptr;
751759

752760
uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
753761
maxFunctionCount = 0, numCounts = 0, numFunctions = 0;
754-
if (!getInt64Value(mdTuple->getOperand(summayIdx++), "TotalCount",
755-
totalCount))
762+
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
763+
mdTuple->getOperand(summayIdx++),
764+
"TotalCount", totalCount))
756765
return nullptr;
757-
if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxCount", maxCount))
766+
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
767+
mdTuple->getOperand(summayIdx++),
768+
"MaxCount", maxCount))
758769
return nullptr;
759-
if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxInternalCount",
760-
maxInternalCount))
770+
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
771+
mdTuple->getOperand(summayIdx++),
772+
"MaxInternalCount", maxInternalCount))
761773
return nullptr;
762-
if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxFunctionCount",
763-
maxFunctionCount))
774+
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
775+
mdTuple->getOperand(summayIdx++),
776+
"MaxFunctionCount", maxFunctionCount))
764777
return nullptr;
765-
if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumCounts", numCounts))
778+
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
779+
mdTuple->getOperand(summayIdx++),
780+
"NumCounts", numCounts))
766781
return nullptr;
767-
if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumFunctions",
768-
numFunctions))
782+
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
783+
mdTuple->getOperand(summayIdx++),
784+
"NumFunctions", numFunctions))
769785
return nullptr;
770786

771787
// Handle optional keys.
@@ -786,15 +802,16 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
786802
summayIdx++;
787803

788804
// Handle detailed summary.
789-
SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailedSummary;
790-
if (!getSummary(mdTuple->getOperand(summayIdx), detailedSummary))
805+
SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailed;
806+
if (!convertProfileSummaryDetailed(mlirModule, llvmModule,
807+
mdTuple->getOperand(summayIdx), detailed))
791808
return nullptr;
792809

793810
// Build the final profile summary attribute.
794811
return ModuleFlagProfileSummaryAttr::get(
795812
mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount,
796813
maxFunctionCount, numCounts, numFunctions, isPartialProfile,
797-
partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary);
814+
partialProfileRatio ? partialProfileRatio : nullptr, detailed);
798815
}
799816

800817
/// Invoke specific handlers for each known module flag value, returns nullptr

0 commit comments

Comments
 (0)