Skip to content

Commit 6f203a2

Browse files
committed
Use FailureOr
1 parent 0eff0a4 commit 6f203a2

File tree

1 file changed

+55
-50
lines changed

1 file changed

+55
-50
lines changed

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
570570
llvm::MDTuple *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, md);
571571
if (!tupleEntry)
572572
return nullptr;
573-
llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
573+
auto *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
574574
if (!keyMD || keyMD->getString() != matchKey) {
575575
if (!optional)
576576
emitWarning(mlirModule.getLoc())
@@ -582,24 +582,22 @@ static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
582582
return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
583583
}
584584

585-
static bool convertInt64FromKeyValueTuple(ModuleOp mlirModule,
586-
const llvm::Module *llvmModule,
587-
const llvm::MDOperand &md,
588-
StringRef matchKey, uint64_t &val) {
585+
static FailureOr<uint64_t>
586+
convertInt64FromKeyValueTuple(ModuleOp mlirModule,
587+
const llvm::Module *llvmModule,
588+
const llvm::MDOperand &md, StringRef matchKey) {
589589
auto *valMD =
590590
getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey);
591591
if (!valMD)
592-
return false;
592+
return failure();
593593

594-
if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue())) {
595-
val = cstInt->getZExtValue();
596-
return true;
597-
}
594+
if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue()))
595+
return cstInt->getZExtValue();
598596

599597
emitWarning(mlirModule.getLoc())
600598
<< "expected integer metadata value for key '" << matchKey
601599
<< "': " << diagMD(md, llvmModule);
602-
return false;
600+
return failure();
603601
}
604602

605603
static std::optional<ProfileSummaryFormatKind>
@@ -631,37 +629,38 @@ convertProfileSummaryFormat(ModuleOp mlirModule, const llvm::Module *llvmModule,
631629
return fmtKind;
632630
}
633631

634-
static bool convertProfileSummaryDetailed(
635-
ModuleOp mlirModule, const llvm::Module *llvmModule,
636-
const llvm::MDOperand &summaryMD,
637-
SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr> &detailedSummary) {
632+
static FailureOr<SmallVector<ModuleFlagProfileSummaryDetailedAttr>>
633+
convertProfileSummaryDetailed(ModuleOp mlirModule,
634+
const llvm::Module *llvmModule,
635+
const llvm::MDOperand &summaryMD) {
638636
auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, summaryMD);
639637
if (!tupleEntry)
640-
return false;
638+
return failure();
641639

642640
llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
643641
if (!keyMD || keyMD->getString() != "DetailedSummary") {
644642
emitWarning(mlirModule.getLoc())
645643
<< "expected 'DetailedSummary' key: "
646644
<< diagMD(tupleEntry->getOperand(0), llvmModule);
647-
return false;
645+
return failure();
648646
}
649647

650648
llvm::MDTuple *entriesMD = dyn_cast<llvm::MDTuple>(tupleEntry->getOperand(1));
651649
if (!entriesMD) {
652650
emitWarning(mlirModule.getLoc())
653651
<< "expected tuple value for 'DetailedSummary' key: "
654652
<< diagMD(tupleEntry->getOperand(1), llvmModule);
655-
return false;
653+
return failure();
656654
}
657655

656+
SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailedSummary;
658657
for (auto &&entry : entriesMD->operands()) {
659658
llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
660659
if (!entryMD || entryMD->getNumOperands() != 3) {
661660
emitWarning(mlirModule.getLoc())
662661
<< "'DetailedSummary' entry expects 3 operands: "
663662
<< diagMD(entry, llvmModule);
664-
return false;
663+
return failure();
665664
}
666665
llvm::ConstantAsMetadata *op0 =
667666
dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(0));
@@ -674,7 +673,7 @@ static bool convertProfileSummaryDetailed(
674673
emitWarning(mlirModule.getLoc())
675674
<< "expected only integer entries in 'DetailedSummary': "
676675
<< diagMD(entry, llvmModule);
677-
return false;
676+
return failure();
678677
}
679678

680679
auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get(
@@ -684,7 +683,7 @@ static bool convertProfileSummaryDetailed(
684683
cast<llvm::ConstantInt>(op2->getValue())->getZExtValue());
685684
detailedSummary.push_back(detaildSummaryEntry);
686685
}
687-
return true;
686+
return detailedSummary;
688687
}
689688

690689
static Attribute
@@ -722,9 +721,9 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
722721
return success();
723722
if (checkOptionalPosition(md, matchKey).failed())
724723
return failure();
725-
uint64_t tmpVal = 0;
726-
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
727-
tmpVal))
724+
FailureOr<uint64_t> tmpVal =
725+
convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey);
726+
if (failed(tmpVal))
728727
return failure();
729728
val = tmpVal;
730729
return success();
@@ -757,31 +756,36 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
757756
if (!format.has_value())
758757
return nullptr;
759758

760-
uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
761-
maxFunctionCount = 0, numCounts = 0, numFunctions = 0;
762-
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
763-
mdTuple->getOperand(summayIdx++),
764-
"TotalCount", totalCount))
759+
FailureOr<uint64_t> totalCount = convertInt64FromKeyValueTuple(
760+
mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "TotalCount");
761+
if (failed(totalCount))
765762
return nullptr;
766-
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
767-
mdTuple->getOperand(summayIdx++),
768-
"MaxCount", maxCount))
763+
764+
FailureOr<uint64_t> maxCount = convertInt64FromKeyValueTuple(
765+
mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "MaxCount");
766+
if (failed(maxCount))
769767
return nullptr;
770-
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
771-
mdTuple->getOperand(summayIdx++),
772-
"MaxInternalCount", maxInternalCount))
768+
769+
FailureOr<uint64_t> maxInternalCount = convertInt64FromKeyValueTuple(
770+
mlirModule, llvmModule, mdTuple->getOperand(summayIdx++),
771+
"MaxInternalCount");
772+
if (failed(maxInternalCount))
773773
return nullptr;
774-
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
775-
mdTuple->getOperand(summayIdx++),
776-
"MaxFunctionCount", maxFunctionCount))
774+
775+
FailureOr<uint64_t> maxFunctionCount = convertInt64FromKeyValueTuple(
776+
mlirModule, llvmModule, mdTuple->getOperand(summayIdx++),
777+
"MaxFunctionCount");
778+
if (failed(maxFunctionCount))
777779
return nullptr;
778-
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
779-
mdTuple->getOperand(summayIdx++),
780-
"NumCounts", numCounts))
780+
781+
FailureOr<uint64_t> numCounts = convertInt64FromKeyValueTuple(
782+
mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "NumCounts");
783+
if (failed(numCounts))
781784
return nullptr;
782-
if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
783-
mdTuple->getOperand(summayIdx++),
784-
"NumFunctions", numFunctions))
785+
786+
FailureOr<uint64_t> numFunctions = convertInt64FromKeyValueTuple(
787+
mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "NumFunctions");
788+
if (failed(numFunctions))
785789
return nullptr;
786790

787791
// Handle optional keys.
@@ -802,16 +806,17 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
802806
summayIdx++;
803807

804808
// Handle detailed summary.
805-
SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailed;
806-
if (!convertProfileSummaryDetailed(mlirModule, llvmModule,
807-
mdTuple->getOperand(summayIdx), detailed))
809+
FailureOr<SmallVector<ModuleFlagProfileSummaryDetailedAttr>> detailed =
810+
convertProfileSummaryDetailed(mlirModule, llvmModule,
811+
mdTuple->getOperand(summayIdx));
812+
if (failed(detailed))
808813
return nullptr;
809814

810815
// Build the final profile summary attribute.
811816
return ModuleFlagProfileSummaryAttr::get(
812-
mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount,
813-
maxFunctionCount, numCounts, numFunctions, isPartialProfile,
814-
partialProfileRatio ? partialProfileRatio : nullptr, detailed);
817+
mlirModule->getContext(), *format, *totalCount, *maxCount,
818+
*maxInternalCount, *maxFunctionCount, *numCounts, *numFunctions,
819+
isPartialProfile, partialProfileRatio, *detailed);
815820
}
816821

817822
/// Invoke specific handlers for each known module flag value, returns nullptr

0 commit comments

Comments
 (0)