From 66bba757ef8e7ad43a10e80a4540ba7bb51ecdac Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Fri, 25 Apr 2025 14:45:11 -0700 Subject: [PATCH 01/14] [MLIR][LLVM] Add ProfileSummary module flag support Unlike "CG Profile", LLVM proper does not verify the content of the metadata, but returns an empty one in case it's ill-formed. To that intent the importer here does a significant amount of checks to avoid consuming bad content. --- .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 48 ++++ .../mlir/Dialect/LLVMIR/LLVMDialect.td | 3 + mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp | 13 + .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 68 +++++ mlir/lib/Target/LLVMIR/ModuleImport.cpp | 259 +++++++++++++++++- mlir/test/Dialect/LLVMIR/invalid.mlir | 7 + .../test/Dialect/LLVMIR/module-roundtrip.mlir | 26 +- .../Target/LLVMIR/Import/import-failure.ll | 124 +++++++++ .../test/Target/LLVMIR/Import/module-flags.ll | 67 ++++- mlir/test/Target/LLVMIR/llvmir.mlir | 27 ++ 10 files changed, 627 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index 7d6d38ecad897..5eb66745db829 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1378,6 +1378,54 @@ def ModuleFlagCGProfileEntryAttr let assemblyFormat = "`<` struct(params) `>`"; } +def ModuleFlagProfileSummaryDetailedAttr + : LLVM_Attr<"ModuleFlagProfileSummaryDetailed", "profile_summary_detailed"> { + let summary = "ProfileSummary detailed information"; + let description = [{ + Contains detailed information pertinent to "ProfileSummary" attribute. + A `#llvm.profile_summary` may contain several of it. + ```mlir + llvm.module_flags [ ... + detailed_summary = [ + #llvm.profile_summary_detailed, + #llvm.profile_summary_detailed + ``` + }]; + let parameters = (ins "uint32_t":$cut_off, + "uint64_t":$min_count, + "uint32_t":$num_counts); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def ModuleFlagProfileSummaryAttr + : LLVM_Attr<"ModuleFlagProfileSummary", "profile_summary"> { + let summary = "ProfileSummary module flag"; + let description = [{ + Describes ProfileSummary gathered data in a module. Example: + ```mlir + llvm.module_flags [#llvm.mlir.module_flag, + #llvm.profile_summary_detailed + ]>>] + ``` + }]; + let parameters = ( + ins "StringAttr":$format, "uint64_t":$total_count, "uint64_t":$max_count, + "uint64_t":$max_internal_count, "uint64_t":$max_function_count, + "uint64_t":$num_counts, "uint64_t":$num_functions, + OptionalParameter<"IntegerAttr">:$is_partial_profile, + OptionalParameter<"FloatAttr">:$partial_profile_ratio, + "ArrayAttr":$detailed_summary); + + let assemblyFormat = "`<` struct(params) `>`"; +} + //===----------------------------------------------------------------------===// // LLVM_DependentLibrariesAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td index 9f9d075a3eebf..b5ea8fc5da500 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td @@ -92,6 +92,9 @@ def LLVM_Dialect : Dialect { static StringRef getModuleFlagKeyCGProfileName() { return "CG Profile"; } + static StringRef getModuleFlagKeyProfileSummaryName() { + return "ProfileSummary"; + } /// Returns `true` if the given type is compatible with the LLVM dialect. static bool isCompatibleType(Type); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index ffde597ac83c1..ef689e3721d91 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -390,6 +390,19 @@ ModuleFlagAttr::verify(function_ref emitError, return success(); } + if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) { + if (auto summaryAttr = dyn_cast(value)) { + StringRef fmt = summaryAttr.getFormat().getValue(); + if (fmt != "SampleProfile" && fmt != "InstrProf" && fmt != "CSInstrProf") + return emitError() << "'ProfileFormat' must be 'SampleProfile', " + "'InstrProf' or 'CSInstrProf'"; + } else { + return emitError() << "'ProfileSummary' key expects a " + "'#llvm.profile_summary' attribute"; + } + return success(); + } + if (isa(value)) return success(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 35dcde2a33d41..260d61f97fce5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -300,9 +300,72 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr, } return llvm::MDTuple::getDistinct(context, nodes); } + return nullptr; } +static llvm::Metadata *convertModuleFlagProfileSummaryAttr( + StringRef key, ModuleFlagProfileSummaryAttr summaryAttr, + llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { + llvm::LLVMContext &context = builder.getContext(); + llvm::MDBuilder mdb(context); + SmallVector summaryNodes; + + auto getIntTuple = [&](StringRef key, uint64_t val) -> llvm::MDTuple * { + SmallVector tupleNodes{ + mdb.createString(key), mdb.createConstant(llvm::ConstantInt::get( + llvm::Type::getInt64Ty(context), val))}; + return llvm::MDTuple::get(context, tupleNodes); + }; + + SmallVector fmtNode{ + mdb.createString("ProfileFormat"), + mdb.createString(summaryAttr.getFormat().getValue())}; + + SmallVector vals = { + llvm::MDTuple::get(context, fmtNode), + getIntTuple("TotalCount", summaryAttr.getTotalCount()), + getIntTuple("MaxCount", summaryAttr.getMaxCount()), + getIntTuple("MaxInternalCount", summaryAttr.getMaxInternalCount()), + getIntTuple("MaxFunctionCount", summaryAttr.getMaxFunctionCount()), + getIntTuple("NumCounts", summaryAttr.getNumCounts()), + getIntTuple("NumFunctions", summaryAttr.getNumFunctions()), + }; + + if (summaryAttr.getIsPartialProfile()) + vals.push_back(getIntTuple("IsPartialProfile", + summaryAttr.getIsPartialProfile().getUInt())); + + if (summaryAttr.getPartialProfileRatio()) { + SmallVector tupleNodes{ + mdb.createString("PartialProfileRatio"), + mdb.createConstant(llvm::ConstantFP::get( + llvm::Type::getDoubleTy(context), + summaryAttr.getPartialProfileRatio().getValue()))}; + vals.push_back(llvm::MDTuple::get(context, tupleNodes)); + } + + SmallVector detailedEntries; + for (auto detailedEntry : + summaryAttr.getDetailedSummary() + .getAsRange()) { + SmallVector tupleNodes{ + mdb.createConstant(llvm::ConstantInt::get( + llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())), + mdb.createConstant(llvm::ConstantInt::get( + llvm::Type::getInt64Ty(context), detailedEntry.getMinCount())), + mdb.createConstant(llvm::ConstantInt::get( + llvm::Type::getInt64Ty(context), detailedEntry.getNumCounts()))}; + detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes)); + } + SmallVector detailedSummary{ + mdb.createString("DetailedSummary"), + llvm::MDTuple::get(context, detailedEntries)}; + vals.push_back(llvm::MDTuple::get(context, detailedSummary)); + + return llvm::MDNode::get(context, vals); +} + static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); @@ -323,6 +386,11 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder, arrayAttr, builder, moduleTranslation); }) + .Case([&](auto summaryAttr) { + return convertModuleFlagProfileSummaryAttr( + flagAttr.getKey().getValue(), summaryAttr, builder, + moduleTranslation); + }) .Default([](auto) { return nullptr; }); assert(valueMetadata && "expected valid metadata"); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 0b77a3d23d392..fff2ae4a65f2d 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -554,13 +554,262 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule, return ArrayAttr::get(mlirModule->getContext(), cgProfile); } +static Attribute +convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, + const llvm::Module *llvmModule, + llvm::MDTuple *mdTuple) { + unsigned profileNumEntries = mdTuple->getNumOperands(); + if (profileNumEntries < 8) { + emitWarning(mlirModule.getLoc()) + << "expected at 8 entries in 'ProfileSummary': " + << diagMD(mdTuple, llvmModule); + return nullptr; + } + + unsigned summayIdx = 0; + + auto getMDTuple = [&](const llvm::MDOperand &md) { + auto *tupleEntry = dyn_cast_or_null(md); + if (!tupleEntry || tupleEntry->getNumOperands() != 2) + emitWarning(mlirModule.getLoc()) + << "expected 2-element tuple metadata: " << diagMD(md, llvmModule); + return tupleEntry; + }; + + auto getFormat = [&](const llvm::MDOperand &formatMD) -> StringAttr { + auto *tupleEntry = getMDTuple(formatMD); + if (!tupleEntry) + return nullptr; + + llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); + if (!keyMD || keyMD->getString() != "ProfileFormat") { + emitWarning(mlirModule.getLoc()) + << "expected 'ProfileFormat' key: " + << diagMD(tupleEntry->getOperand(0), llvmModule); + return nullptr; + } + + llvm::MDString *valMD = dyn_cast(tupleEntry->getOperand(1)); + auto formatAttr = llvm::StringSwitch(valMD->getString()) + .Case("SampleProfile", "SampleProfile") + .Case("InstrProf", "InstrProf") + .Case("CSInstrProf", "CSInstrProf") + .Default(""); + if (formatAttr.empty()) { + emitWarning(mlirModule.getLoc()) + << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, " + "but found: " + << diagMD(valMD, llvmModule); + return nullptr; + } + + return StringAttr::get(mlirModule->getContext(), formatAttr); + }; + + auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey, + bool optional = + false) -> llvm::ConstantAsMetadata * { + auto *tupleEntry = getMDTuple(md); + if (!tupleEntry) + return nullptr; + llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); + if (!keyMD || keyMD->getString() != matchKey) { + if (!optional) + emitWarning(mlirModule.getLoc()) + << "expected '" << matchKey << "' key, but found: " + << diagMD(tupleEntry->getOperand(0), llvmModule); + return nullptr; + } + + return dyn_cast(tupleEntry->getOperand(1)); + }; + + auto checkOptionalPosition = [&](const llvm::MDOperand &md, + StringRef matchKey) -> LogicalResult { + // Make sure we won't step over the bound of the array of summary entries. + // Since (non-optional) DetailedSummary always comes last, the next entry in + // the tuple operand array must exist. + if (summayIdx + 1 >= profileNumEntries) { + emitWarning(mlirModule.getLoc()) + << "the last summary entry is '" << matchKey + << "', expected 'DetailedSummary': " << diagMD(md, llvmModule); + return failure(); + } + + return success(); + }; + + auto getInt64Value = [&](const llvm::MDOperand &md, StringRef matchKey, + uint64_t &val) { + auto *valMD = getConstantMD(md, matchKey); + if (!valMD) + return false; + + if (auto *cstInt = dyn_cast(valMD->getValue())) { + val = cstInt->getZExtValue(); + return true; + } + + emitWarning(mlirModule.getLoc()) + << "expected integer metadata value for key '" << matchKey + << "': " << diagMD(md, llvmModule); + return false; + }; + + auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey, + IntegerAttr &attr) -> LogicalResult { + if (!getConstantMD(md, matchKey, /*optional=*/true)) + return success(); + if (checkOptionalPosition(md, matchKey).failed()) + return failure(); + uint64_t val = 0; + if (!getInt64Value(md, matchKey, val)) + return failure(); + attr = + IntegerAttr::get(IntegerType::get(mlirModule->getContext(), 64), val); + return success(); + }; + + auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey, + FloatAttr &attr) -> LogicalResult { + auto *valMD = getConstantMD(md, matchKey, /*optional=*/true); + if (!valMD) + return success(); + if (auto *cstFP = dyn_cast(valMD->getValue())) { + if (checkOptionalPosition(md, matchKey).failed()) + return failure(); + attr = FloatAttr::get(Float64Type::get(mlirModule.getContext()), + cstFP->getValueAPF()); + return success(); + } + emitWarning(mlirModule.getLoc()) + << "expected double metadata value for key '" << matchKey + << "': " << diagMD(md, llvmModule); + return failure(); + }; + + auto getSummary = [&](const llvm::MDOperand &summaryMD) -> ArrayAttr { + auto *tupleEntry = getMDTuple(summaryMD); + if (!tupleEntry) + return nullptr; + + llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); + if (!keyMD || keyMD->getString() != "DetailedSummary") { + emitWarning(mlirModule.getLoc()) + << "expected 'DetailedSummary' key: " + << diagMD(tupleEntry->getOperand(0), llvmModule); + return nullptr; + } + + llvm::MDTuple *entriesMD = + dyn_cast(tupleEntry->getOperand(1)); + if (!entriesMD) { + emitWarning(mlirModule.getLoc()) + << "expected tuple value for 'DetailedSummary' key: " + << diagMD(tupleEntry->getOperand(1), llvmModule); + return nullptr; + } + + SmallVector detailedSummary; + for (auto &&entry : entriesMD->operands()) { + llvm::MDTuple *entryMD = dyn_cast(entry); + if (!entryMD || entryMD->getNumOperands() != 3) { + emitWarning(mlirModule.getLoc()) + << "'DetailedSummary' entry expects 3 operands: " + << diagMD(entry, llvmModule); + return nullptr; + } + llvm::ConstantAsMetadata *op0 = + dyn_cast(entryMD->getOperand(0)); + llvm::ConstantAsMetadata *op1 = + dyn_cast(entryMD->getOperand(1)); + llvm::ConstantAsMetadata *op2 = + dyn_cast(entryMD->getOperand(2)); + + if (!op0 || !op1 || !op2) { + emitWarning(mlirModule.getLoc()) + << "expected only integer entries in 'DetailedSummary': " + << diagMD(entry, llvmModule); + return nullptr; + } + + auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get( + mlirModule->getContext(), + cast(op0->getValue())->getZExtValue(), + cast(op1->getValue())->getZExtValue(), + cast(op2->getValue())->getZExtValue()); + detailedSummary.push_back(detaildSummaryEntry); + } + return ArrayAttr::get(mlirModule->getContext(), detailedSummary); + }; + + // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in + // a fixed order: format, total count, etc. + SmallVector profileSummary; + StringAttr format = getFormat(mdTuple->getOperand(summayIdx++)); + if (!format) + return nullptr; + + uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0, + maxFunctionCount = 0, numCounts = 0, numFunctions = 0; + if (!getInt64Value(mdTuple->getOperand(summayIdx++), "TotalCount", + totalCount)) + return nullptr; + if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxCount", maxCount)) + return nullptr; + if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxInternalCount", + maxInternalCount)) + return nullptr; + if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxFunctionCount", + maxFunctionCount)) + return nullptr; + if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumCounts", numCounts)) + return nullptr; + if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumFunctions", + numFunctions)) + return nullptr; + + // Handle optional keys. + IntegerAttr isPartialProfile; + if (getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile", + isPartialProfile) + .failed()) + return nullptr; + if (isPartialProfile) + summayIdx++; + + FloatAttr partialProfileRatio; + if (getOptDoubleValue(mdTuple->getOperand(summayIdx), "PartialProfileRatio", + partialProfileRatio) + .failed()) + return nullptr; + if (partialProfileRatio) + summayIdx++; + + // Handle detailed summary. + ArrayAttr detailedSummary = getSummary(mdTuple->getOperand(summayIdx)); + if (!detailedSummary) + return nullptr; + + // Build the final profile summary attribute. + return ModuleFlagProfileSummaryAttr::get( + mlirModule->getContext(), format, totalCount, maxCount, maxInternalCount, + maxFunctionCount, numCounts, numFunctions, + isPartialProfile ? isPartialProfile : nullptr, + partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary); +} + /// Invoke specific handlers for each known module flag value, returns nullptr /// if the key is unknown or unimplemented. -static Attribute convertModuleFlagValueFromMDTuple(ModuleOp mlirModule, - StringRef key, - llvm::MDTuple *mdTuple) { +static Attribute +convertModuleFlagValueFromMDTuple(ModuleOp mlirModule, + const llvm::Module *llvmModule, StringRef key, + llvm::MDTuple *mdTuple) { if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) return convertCGProfileModuleFlagValue(mlirModule, mdTuple); + if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) + return convertProfileSummaryModuleFlagValue(mlirModule, llvmModule, + mdTuple); return nullptr; } @@ -576,8 +825,8 @@ LogicalResult ModuleImport::convertModuleFlagsMetadata() { } else if (auto *mdString = dyn_cast(val)) { valAttr = builder.getStringAttr(mdString->getString()); } else if (auto *mdTuple = dyn_cast(val)) { - valAttr = convertModuleFlagValueFromMDTuple(mlirModule, key->getString(), - mdTuple); + valAttr = convertModuleFlagValueFromMDTuple(mlirModule, llvmModule.get(), + key->getString(), mdTuple); } if (!valAttr) { diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 5dea94026b248..84c0d40c8b346 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1800,6 +1800,13 @@ module { // ----- +module { + // expected-error@below {{'ProfileSummary' key expects a '#llvm.profile_summary' attribute}} + llvm.module_flags [#llvm.mlir.module_flag] +} + +// ----- + llvm.func @t0() -> !llvm.ptr { %0 = llvm.blockaddress > : !llvm.ptr llvm.blocktag diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir index 025d9b2287c42..62a16de6b6d97 100644 --- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir @@ -11,7 +11,17 @@ module { #llvm.cgprofile_entry, #llvm.cgprofile_entry, #llvm.cgprofile_entry - ]>] + ]>, + #llvm.mlir.module_flag, + #llvm.profile_summary_detailed + ]>>] } // CHECK: llvm.module_flags [ @@ -25,4 +35,16 @@ module { // CHECK-SAME: #llvm.cgprofile_entry, // CHECK-SAME: #llvm.cgprofile_entry, // CHECK-SAME: #llvm.cgprofile_entry -// CHECK-SAME: ]>] +// CHECK-SAME: ]>, +// CHECK-SAME: #llvm.mlir.module_flag, +// CHECK-SAME: #llvm.profile_summary_detailed +// CHECK-SAME: ]>>] + +llvm.module_flags [] diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll index 782925a0a938e..7571158a57d14 100644 --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -348,3 +348,127 @@ define void @fn() { bb1: ret void } + +; // ----- + +!10 = !{ i32 1, !"foo", i32 1 } +!11 = !{ i32 4, !"bar", i32 37 } +!12 = !{ i32 2, !"qux", i32 42 } +; CHECK: unsupported module flag value for key 'qux' : !4 = !{!"foo", i32 1} +!13 = !{ i32 3, !"qux", !{ !"foo", i32 1 }} +!llvm.module.flags = !{ !10, !11, !12, !13 } + +; // ----- + +!llvm.module.flags = !{!41873} + +!41873 = !{i32 1, !"ProfileSummary", !41874} +!41874 = !{!41875, !41876, !41877, !41878, !41880, !41881, !41882, !41883, !41884} +!41875 = !{!"ProfileFormat", !"InstrProf"} +!41876 = !{!"TotalCount", i64 263646} +!41877 = !{!"MaxCount", i64 86427} +!41878 = !{!"MaxInternalCount", i64 86427} +; CHECK: expected 'MaxFunctionCount' key, but found: !"NumCounts" +!41880 = !{!"NumCounts", i64 3712} +!41881 = !{!"NumFunctions", i64 796} +!41882 = !{!"IsPartialProfile", i64 0} +!41883 = !{!"PartialProfileRatio", double 0.000000e+00} +!41884 = !{!"DetailedSummary", !41885} +!41885 = !{!41886, !41887} +!41886 = !{i32 10000, i64 86427, i32 1} +!41887 = !{i32 100000, i64 86427, i32 1} + +; // ----- + +!llvm.module.flags = !{!51873} + +!51873 = !{i32 1, !"ProfileSummary", !51874} +!51874 = !{!51875, !51876, !51877, !51878, !51879, !51880, !51881, !51882, !51883, !51884} +!51875 = !{!"ProfileFormat", !"InstrProf"} +!51876 = !{!"TotalCount", i64 263646} +!51877 = !{!"MaxCount", i64 86427} +!51878 = !{!"MaxInternalCount", i64 86427} +!51879 = !{!"MaxFunctionCount", i64 4691} +!51880 = !{!"NumCounts", i64 3712} +; CHECK: expected integer metadata value for key 'NumFunctions' +!51881 = !{!"NumFunctions", double 0.000000e+00} +!51882 = !{!"IsPartialProfile", i64 0} +!51883 = !{!"PartialProfileRatio", double 0.000000e+00} +!51884 = !{!"DetailedSummary", !51885} +!51885 = !{!51886, !51887} +!51886 = !{i32 10000, i64 86427, i32 1} +!51887 = !{i32 100000, i64 86427, i32 1} + +; // ----- + +!llvm.module.flags = !{!61873} + +!61873 = !{i32 1, !"ProfileSummary", !61874} +!61874 = !{!61875, !61876, !61877, !61878, !61879, !61880, !61881, !61882, !61883, !61884} +; CHECK: expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, but found: !"MyThingyFmt" +!61875 = !{!"ProfileFormat", !"MyThingyFmt"} +!61876 = !{!"TotalCount", i64 263646} +!61877 = !{!"MaxCount", i64 86427} +!61878 = !{!"MaxInternalCount", i64 86427} +!61879 = !{!"MaxFunctionCount", i64 4691} +!61880 = !{!"NumCounts", i64 3712} +!61881 = !{!"NumFunctions", i64 796} +!61882 = !{!"IsPartialProfile", i64 0} +!61883 = !{!"PartialProfileRatio", double 0.000000e+00} +!61884 = !{!"DetailedSummary", !61885} +!61885 = !{!61886, !61887} +!61886 = !{i32 10000, i64 86427, i32 1} +!61887 = !{i32 100000, i64 86427, i32 1} + +; // ----- + +!llvm.module.flags = !{!71873} + +!71873 = !{i32 1, !"ProfileSummary", !71874} +!71874 = !{!71875, !71876, !71877, !71878, !71879, !71880, !71881, !71882, !71883} +!71875 = !{!"ProfileFormat", !"InstrProf"} +!71876 = !{!"TotalCount", i64 263646} +!71877 = !{!"MaxCount", i64 86427} +!71878 = !{!"MaxInternalCount", i64 86427} +!71879 = !{!"MaxFunctionCount", i64 4691} +!71880 = !{!"NumCounts", i64 3712} +!71881 = !{!"NumFunctions", i64 796} +!71882 = !{!"IsPartialProfile", i64 0} +; CHECK: the last summary entry is 'PartialProfileRatio', expected 'DetailedSummary' +!71883 = !{!"PartialProfileRatio", double 0.000000e+00} + +; // ----- + +!llvm.module.flags = !{!81873} + +!81873 = !{i32 1, !"ProfileSummary", !81874} +; CHECK: expected at 8 entries in 'ProfileSummary' +!81874 = !{!81875, !81876, !81877, !81878, !81879, !81880, !81881} +!81875 = !{!"ProfileFormat", !"InstrProf"} +!81876 = !{!"TotalCount", i64 263646} +!81877 = !{!"MaxCount", i64 86427} +!81878 = !{!"MaxInternalCount", i64 86427} +!81879 = !{!"MaxFunctionCount", i64 4691} +!81880 = !{!"NumCounts", i64 3812} +!81881 = !{!"NumFunctions", i64 796} + +; // ----- + +!llvm.module.flags = !{!91873} + +!91873 = !{i32 1, !"ProfileSummary", !91874} +!91874 = !{!91875, !91876, !91877, !91878, !91879, !91880, !91881, !91882, !91883, !91884} +!91875 = !{!"ProfileFormat", !"InstrProf"} +; CHECK: expected 2-element tuple metadata +!91876 = !{!"TotalCount", i64 263646, i64 263646} +!91877 = !{!"MaxCount", i64 86427} +!91878 = !{!"MaxInternalCount", i64 86427} +!91879 = !{!"MaxFunctionCount", i64 4691} +!91880 = !{!"NumCounts", i64 3712} +!91881 = !{!"NumFunctions", i64 796} +!91882 = !{!"IsPartialProfile", i64 0} +!91883 = !{!"PartialProfileRatio", double 0.000000e+00} +!91884 = !{!"DetailedSummary", !91885} +!91885 = !{!91886, !91887} +!91886 = !{i32 10000, i64 86427, i32 1} +!91887 = !{i32 100000, i64 86427, i32 1} diff --git a/mlir/test/Target/LLVMIR/Import/module-flags.ll b/mlir/test/Target/LLVMIR/Import/module-flags.ll index 09e708de0cc93..49895c4f26241 100644 --- a/mlir/test/Target/LLVMIR/Import/module-flags.ll +++ b/mlir/test/Target/LLVMIR/Import/module-flags.ll @@ -18,14 +18,6 @@ ; CHECK-SAME: #llvm.mlir.module_flag, ; CHECK-SAME: #llvm.mlir.module_flag] -; // ----- -; expected-warning@-2 {{unsupported module flag value for key 'qux' : !4 = !{!"foo", i32 1}}} -!10 = !{ i32 1, !"foo", i32 1 } -!11 = !{ i32 4, !"bar", i32 37 } -!12 = !{ i32 2, !"qux", i32 42 } -!13 = !{ i32 3, !"qux", !{ !"foo", i32 1 }} -!llvm.module.flags = !{ !10, !11, !12, !13 } - ; // ----- declare void @from(i32) @@ -44,3 +36,62 @@ declare void @to() ; CHECK-SAME: #llvm.cgprofile_entry, ; CHECK-SAME: #llvm.cgprofile_entry ; CHECK-SAME: ]>] + +; // ----- + +!llvm.module.flags = !{!31873} + +!31873 = !{i32 1, !"ProfileSummary", !31874} +!31874 = !{!31875, !31876, !31877, !31878, !31879, !31880, !31881, !31882, !31883, !31884} +!31875 = !{!"ProfileFormat", !"InstrProf"} +!31876 = !{!"TotalCount", i64 263646} +!31877 = !{!"MaxCount", i64 86427} +!31878 = !{!"MaxInternalCount", i64 86427} +!31879 = !{!"MaxFunctionCount", i64 4691} +!31880 = !{!"NumCounts", i64 3712} +!31881 = !{!"NumFunctions", i64 796} +!31882 = !{!"IsPartialProfile", i64 0} +!31883 = !{!"PartialProfileRatio", double 0.000000e+00} +!31884 = !{!"DetailedSummary", !31885} +!31885 = !{!31886, !31887} +!31886 = !{i32 10000, i64 86427, i32 1} +!31887 = !{i32 100000, i64 86427, i32 1} + +; CHECK: llvm.module_flags [#llvm.mlir.module_flag, +; CHECK-SAME: #llvm.profile_summary_detailed +; CHECK-SAME: ]>>] + +; // ----- + +; Test optional fields + +!llvm.module.flags = !{!41873} + +!41873 = !{i32 1, !"ProfileSummary", !41874} +!41874 = !{!41875, !41876, !41877, !41878, !41879, !41880, !41881, !41884} +!41875 = !{!"ProfileFormat", !"InstrProf"} +!41876 = !{!"TotalCount", i64 263646} +!41877 = !{!"MaxCount", i64 86427} +!41878 = !{!"MaxInternalCount", i64 86427} +!41879 = !{!"MaxFunctionCount", i64 4691} +!41880 = !{!"NumCounts", i64 3712} +!41881 = !{!"NumFunctions", i64 796} +!41884 = !{!"DetailedSummary", !41885} +!41885 = !{!41886, !41887} +!41886 = !{i32 10000, i64 86427, i32 1} +!41887 = !{i32 100000, i64 86427, i32 1} + +; CHECK: llvm.module_flags [#llvm.mlir.module_flag, +; CHECK-SAME: #llvm.profile_summary_detailed +; CHECK-SAME: ]>>] diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 9852c4051f0d0..dc347430eb0b7 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -2883,6 +2883,33 @@ llvm.func @to() // ----- +llvm.module_flags [#llvm.mlir.module_flag, + #llvm.profile_summary_detailed + ]>>] + +// CHECK: !llvm.module.flags = !{!0, !15} + +// CHECK: !0 = !{i32 1, !"ProfileSummary", !1} +// CHECK: !1 = !{!2, !3, !4, !5, !6, !7, !8, !9, !10, !11} +// CHECK: !2 = !{!"ProfileFormat", !"InstrProf"} +// CHECK: !3 = !{!"TotalCount", i64 263646} +// CHECK: !4 = !{!"MaxCount", i64 86427} +// CHECK: !5 = !{!"MaxInternalCount", i64 86427} +// CHECK: !6 = !{!"MaxFunctionCount", i64 4691} +// CHECK: !7 = !{!"NumCounts", i64 3712} +// CHECK: !8 = !{!"NumFunctions", i64 796} +// CHECK: !9 = !{!"IsPartialProfile", i64 0} +// CHECK: !10 = !{!"PartialProfileRatio", double 0.000000e+00} + +// ----- + module attributes {llvm.dependent_libraries = ["foo", "bar"]} {} // CHECK: !llvm.dependent-libraries = !{![[#LIBFOO:]], ![[#LIBBAR:]]} From b661bd68a0d4d360486d44f0b8b231fa0b6c726f Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Thu, 1 May 2025 14:37:48 -0700 Subject: [PATCH 02/14] add verifier test for format --- mlir/test/Dialect/LLVMIR/invalid.mlir | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 84c0d40c8b346..8f3fe03a0303e 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1807,6 +1807,20 @@ module { // ----- +// expected-error@below {{'ProfileFormat' must be 'SampleProfile', 'InstrProf' or 'CSInstrProf'}} +llvm.module_flags [#llvm.mlir.module_flag, + #llvm.profile_summary_detailed +]>>] + +// ----- + llvm.func @t0() -> !llvm.ptr { %0 = llvm.blockaddress > : !llvm.ptr llvm.blocktag From abfcb67658004bcdd8775d08d56756d993c0decd Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Thu, 1 May 2025 14:51:33 -0700 Subject: [PATCH 03/14] Use ArrayRefParameter --- .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 16 +++++++------- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 4 +--- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 21 ++++++++++--------- mlir/test/Dialect/LLVMIR/invalid.mlir | 8 +++---- .../test/Dialect/LLVMIR/module-roundtrip.mlir | 16 +++++++------- .../test/Target/LLVMIR/Import/module-flags.ll | 16 +++++++------- mlir/test/Target/LLVMIR/llvmir.mlir | 8 +++---- 7 files changed, 44 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index 5eb66745db829..ef05e884b62fe 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1386,9 +1386,9 @@ def ModuleFlagProfileSummaryDetailedAttr A `#llvm.profile_summary` may contain several of it. ```mlir llvm.module_flags [ ... - detailed_summary = [ - #llvm.profile_summary_detailed, - #llvm.profile_summary_detailed + detailed_summary = + , + ``` }]; let parameters = (ins "uint32_t":$cut_off, @@ -1409,10 +1409,10 @@ def ModuleFlagProfileSummaryAttr num_counts = 3712, num_functions = 796, is_partial_profile = 0 : i64, partial_profile_ratio = 0.000000e+00 : f64, - detailed_summary = [ - #llvm.profile_summary_detailed, - #llvm.profile_summary_detailed - ]>>] + detailed_summary = + , + + >>] ``` }]; let parameters = ( @@ -1421,7 +1421,7 @@ def ModuleFlagProfileSummaryAttr "uint64_t":$num_counts, "uint64_t":$num_functions, OptionalParameter<"IntegerAttr">:$is_partial_profile, OptionalParameter<"FloatAttr">:$partial_profile_ratio, - "ArrayAttr":$detailed_summary); + ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary); let assemblyFormat = "`<` struct(params) `>`"; } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 260d61f97fce5..6a523debf1b3a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -346,9 +346,7 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr( } SmallVector detailedEntries; - for (auto detailedEntry : - summaryAttr.getDetailedSummary() - .getAsRange()) { + for (auto detailedEntry : summaryAttr.getDetailedSummary()) { SmallVector tupleNodes{ mdb.createConstant(llvm::ConstantInt::get( llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())), diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index fff2ae4a65f2d..72ffe2d09e4f8 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -688,17 +688,19 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return failure(); }; - auto getSummary = [&](const llvm::MDOperand &summaryMD) -> ArrayAttr { + auto getSummary = [&](const llvm::MDOperand &summaryMD, + SmallVectorImpl + &detailedSummary) { auto *tupleEntry = getMDTuple(summaryMD); if (!tupleEntry) - return nullptr; + return false; llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); if (!keyMD || keyMD->getString() != "DetailedSummary") { emitWarning(mlirModule.getLoc()) << "expected 'DetailedSummary' key: " << diagMD(tupleEntry->getOperand(0), llvmModule); - return nullptr; + return false; } llvm::MDTuple *entriesMD = @@ -707,17 +709,16 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, emitWarning(mlirModule.getLoc()) << "expected tuple value for 'DetailedSummary' key: " << diagMD(tupleEntry->getOperand(1), llvmModule); - return nullptr; + return false; } - SmallVector detailedSummary; for (auto &&entry : entriesMD->operands()) { llvm::MDTuple *entryMD = dyn_cast(entry); if (!entryMD || entryMD->getNumOperands() != 3) { emitWarning(mlirModule.getLoc()) << "'DetailedSummary' entry expects 3 operands: " << diagMD(entry, llvmModule); - return nullptr; + return false; } llvm::ConstantAsMetadata *op0 = dyn_cast(entryMD->getOperand(0)); @@ -730,7 +731,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, emitWarning(mlirModule.getLoc()) << "expected only integer entries in 'DetailedSummary': " << diagMD(entry, llvmModule); - return nullptr; + return false; } auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get( @@ -740,7 +741,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, cast(op2->getValue())->getZExtValue()); detailedSummary.push_back(detaildSummaryEntry); } - return ArrayAttr::get(mlirModule->getContext(), detailedSummary); + return true; }; // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in @@ -787,8 +788,8 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, summayIdx++; // Handle detailed summary. - ArrayAttr detailedSummary = getSummary(mdTuple->getOperand(summayIdx)); - if (!detailedSummary) + SmallVector detailedSummary; + if (!getSummary(mdTuple->getOperand(summayIdx), detailedSummary)) return nullptr; // Build the final profile summary attribute. diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 8f3fe03a0303e..4f35358c9486a 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1814,10 +1814,10 @@ llvm.module_flags [#llvm.mlir.module_flag, - #llvm.profile_summary_detailed -]>>] + detailed_summary = + , + +>>] // ----- diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir index 62a16de6b6d97..bd6162f15527c 100644 --- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir @@ -18,10 +18,10 @@ module { num_counts = 3712, num_functions = 796, is_partial_profile = 0 : i64, partial_profile_ratio = 0.000000e+00 : f64, - detailed_summary = [ - #llvm.profile_summary_detailed, - #llvm.profile_summary_detailed - ]>>] + detailed_summary = + , + + >>] } // CHECK: llvm.module_flags [ @@ -42,9 +42,9 @@ module { // CHECK-SAME: num_counts = 3712, num_functions = 796, // CHECK-SAME: is_partial_profile = 0 : i64, // CHECK-SAME: partial_profile_ratio = 0.000000e+00 : f64, -// CHECK-SAME: detailed_summary = [ -// CHECK-SAME: #llvm.profile_summary_detailed, -// CHECK-SAME: #llvm.profile_summary_detailed -// CHECK-SAME: ]>>] +// CHECK-SAME: detailed_summary = +// CHECK-SAME: , +// CHECK-SAME: +// CHECK-SAME: >>] llvm.module_flags [] diff --git a/mlir/test/Target/LLVMIR/Import/module-flags.ll b/mlir/test/Target/LLVMIR/Import/module-flags.ll index 49895c4f26241..11df41a630d05 100644 --- a/mlir/test/Target/LLVMIR/Import/module-flags.ll +++ b/mlir/test/Target/LLVMIR/Import/module-flags.ll @@ -62,10 +62,10 @@ declare void @to() ; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691, ; CHECK-SAME: num_counts = 3712, num_functions = 796, is_partial_profile = 0 : i64, ; CHECK-SAME: partial_profile_ratio = 0.000000e+00 : f64, -; CHECK-SAME: detailed_summary = [ -; CHECK-SAME: #llvm.profile_summary_detailed, -; CHECK-SAME: #llvm.profile_summary_detailed -; CHECK-SAME: ]>>] +; CHECK-SAME: detailed_summary = +; CHECK-SAME: , +; CHECK-SAME: +; CHECK-SAME: >>] ; // ----- @@ -91,7 +91,7 @@ declare void @to() ; CHECK-SAME: #llvm.profile_summary, -; CHECK-SAME: #llvm.profile_summary_detailed -; CHECK-SAME: ]>>] +; CHECK-SAME: detailed_summary = +; CHECK-SAME: , +; CHECK-SAME: +; CHECK-SAME: >>] diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index dc347430eb0b7..a8da126e698ec 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -2889,10 +2889,10 @@ llvm.module_flags [#llvm.mlir.module_flag, - #llvm.profile_summary_detailed - ]>>] + detailed_summary = + , + + >>] // CHECK: !llvm.module.flags = !{!0, !15} From b4acc60168c16ba0592e624338f4b51059d82d9a Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Thu, 1 May 2025 15:07:35 -0700 Subject: [PATCH 04/14] Use std::optional when possible for attr params --- mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 4 ++-- .../Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 4 ++-- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 16 +++++++--------- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +- mlir/test/Dialect/LLVMIR/module-roundtrip.mlir | 4 ++-- mlir/test/Target/LLVMIR/Import/module-flags.ll | 2 +- mlir/test/Target/LLVMIR/llvmir.mlir | 2 +- 7 files changed, 16 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index ef05e884b62fe..5a037a767a75f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1407,7 +1407,7 @@ def ModuleFlagProfileSummaryAttr #llvm.profile_summary, @@ -1419,7 +1419,7 @@ def ModuleFlagProfileSummaryAttr ins "StringAttr":$format, "uint64_t":$total_count, "uint64_t":$max_count, "uint64_t":$max_internal_count, "uint64_t":$max_function_count, "uint64_t":$num_counts, "uint64_t":$num_functions, - OptionalParameter<"IntegerAttr">:$is_partial_profile, + OptionalParameter<"std::optional">:$is_partial_profile, OptionalParameter<"FloatAttr">:$partial_profile_ratio, ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 6a523debf1b3a..37f07475b3f02 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -333,8 +333,8 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr( }; if (summaryAttr.getIsPartialProfile()) - vals.push_back(getIntTuple("IsPartialProfile", - summaryAttr.getIsPartialProfile().getUInt())); + vals.push_back( + getIntTuple("IsPartialProfile", *summaryAttr.getIsPartialProfile())); if (summaryAttr.getPartialProfileRatio()) { SmallVector tupleNodes{ diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 72ffe2d09e4f8..028e3ca5d903e 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -657,16 +657,15 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, }; auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey, - IntegerAttr &attr) -> LogicalResult { + std::optional &val) -> LogicalResult { if (!getConstantMD(md, matchKey, /*optional=*/true)) return success(); if (checkOptionalPosition(md, matchKey).failed()) return failure(); - uint64_t val = 0; - if (!getInt64Value(md, matchKey, val)) + uint64_t tmpVal = 0; + if (!getInt64Value(md, matchKey, tmpVal)) return failure(); - attr = - IntegerAttr::get(IntegerType::get(mlirModule->getContext(), 64), val); + val = tmpVal; return success(); }; @@ -771,12 +770,12 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return nullptr; // Handle optional keys. - IntegerAttr isPartialProfile; + std::optional isPartialProfile; if (getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile", isPartialProfile) .failed()) return nullptr; - if (isPartialProfile) + if (isPartialProfile.has_value()) summayIdx++; FloatAttr partialProfileRatio; @@ -795,8 +794,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, // Build the final profile summary attribute. return ModuleFlagProfileSummaryAttr::get( mlirModule->getContext(), format, totalCount, maxCount, maxInternalCount, - maxFunctionCount, numCounts, numFunctions, - isPartialProfile ? isPartialProfile : nullptr, + maxFunctionCount, numCounts, numFunctions, isPartialProfile, partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 4f35358c9486a..bb730b28b947d 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1812,7 +1812,7 @@ llvm.module_flags [#llvm.mlir.module_flag, diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir index bd6162f15527c..148b1eb87fa75 100644 --- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir @@ -16,7 +16,7 @@ module { #llvm.profile_summary, @@ -40,7 +40,7 @@ module { // CHECK-SAME: #llvm.profile_summary, diff --git a/mlir/test/Target/LLVMIR/Import/module-flags.ll b/mlir/test/Target/LLVMIR/Import/module-flags.ll index 11df41a630d05..8e6a47921ee38 100644 --- a/mlir/test/Target/LLVMIR/Import/module-flags.ll +++ b/mlir/test/Target/LLVMIR/Import/module-flags.ll @@ -60,7 +60,7 @@ declare void @to() ; CHECK: llvm.module_flags [#llvm.mlir.module_flag, diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index a8da126e698ec..1d4fd6b1cfd67 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -2887,7 +2887,7 @@ llvm.module_flags [#llvm.mlir.module_flag, From def6b16f3955e6ed7383da74d422075cce32e792 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Thu, 1 May 2025 16:32:25 -0700 Subject: [PATCH 05/14] Use enum kind for format --- .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 16 +++++------ mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 17 +++++++++++- mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp | 8 +----- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 3 ++- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 27 +++++++++---------- mlir/test/Dialect/LLVMIR/invalid.mlir | 4 ++- .../test/Dialect/LLVMIR/module-roundtrip.mlir | 4 +-- .../test/Target/LLVMIR/Import/module-flags.ll | 4 +-- mlir/test/Target/LLVMIR/llvmir.mlir | 2 +- 9 files changed, 48 insertions(+), 37 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index 5a037a767a75f..ade2b64c108ff 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1404,7 +1404,7 @@ def ModuleFlagProfileSummaryAttr Describes ProfileSummary gathered data in a module. Example: ```mlir llvm.module_flags [#llvm.mlir.module_flag>] ``` }]; - let parameters = ( - ins "StringAttr":$format, "uint64_t":$total_count, "uint64_t":$max_count, - "uint64_t":$max_internal_count, "uint64_t":$max_function_count, - "uint64_t":$num_counts, "uint64_t":$num_functions, - OptionalParameter<"std::optional">:$is_partial_profile, - OptionalParameter<"FloatAttr">:$partial_profile_ratio, - ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary); + let parameters = (ins "ProfileSummaryFormatKind":$format, + "uint64_t":$total_count, "uint64_t":$max_count, + "uint64_t":$max_internal_count, "uint64_t":$max_function_count, + "uint64_t":$num_counts, "uint64_t":$num_functions, + OptionalParameter<"std::optional">:$is_partial_profile, + OptionalParameter<"FloatAttr">:$partial_profile_ratio, + ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary); let assemblyFormat = "`<` struct(params) `>`"; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td index 6c0fe363d5551..7f5052948ab6c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td @@ -823,7 +823,7 @@ def FPExceptionBehaviorAttr : LLVM_EnumAttr< } //===----------------------------------------------------------------------===// -// Module Flag Behavior +// Module Flags //===----------------------------------------------------------------------===// // These values must match llvm::Module::ModFlagBehavior ones. @@ -855,6 +855,21 @@ def ModFlagBehaviorAttr : LLVM_EnumAttr< let cppNamespace = "::mlir::LLVM"; } +def LLVM_ProfileSummaryFormatSampleProfile : I64EnumAttrCase<"SampleProfile", + 0>; +def LLVM_ProfileSummaryFormatInstrProf : I64EnumAttrCase<"InstrProf", 1>; +def LLVM_ProfileSummaryFormatCSInstrProf : I64EnumAttrCase<"CSInstrProf", 2>; + +def LLVM_ProfileSummaryFormatKind : I64EnumAttr< + "ProfileSummaryFormatKind", + "LLVM ProfileSummary format kinds", [ + LLVM_ProfileSummaryFormatSampleProfile, + LLVM_ProfileSummaryFormatInstrProf, + LLVM_ProfileSummaryFormatCSInstrProf, + ]> { + let cppNamespace = "::mlir::LLVM"; +} + //===----------------------------------------------------------------------===// // UWTableKind //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index ef689e3721d91..d5815d39b364b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -391,15 +391,9 @@ ModuleFlagAttr::verify(function_ref emitError, } if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) { - if (auto summaryAttr = dyn_cast(value)) { - StringRef fmt = summaryAttr.getFormat().getValue(); - if (fmt != "SampleProfile" && fmt != "InstrProf" && fmt != "CSInstrProf") - return emitError() << "'ProfileFormat' must be 'SampleProfile', " - "'InstrProf' or 'CSInstrProf'"; - } else { + if (!isa(value)) return emitError() << "'ProfileSummary' key expects a " "'#llvm.profile_summary' attribute"; - } return success(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 37f07475b3f02..1e517ceb827ac 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -320,7 +320,8 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr( SmallVector fmtNode{ mdb.createString("ProfileFormat"), - mdb.createString(summaryAttr.getFormat().getValue())}; + mdb.createString( + stringifyProfileSummaryFormatKind(summaryAttr.getFormat()))}; SmallVector vals = { llvm::MDTuple::get(context, fmtNode), diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 028e3ca5d903e..13cd9229846c9 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -576,34 +576,32 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return tupleEntry; }; - auto getFormat = [&](const llvm::MDOperand &formatMD) -> StringAttr { + auto getFormat = [&](const llvm::MDOperand &formatMD) + -> std::optional { auto *tupleEntry = getMDTuple(formatMD); if (!tupleEntry) - return nullptr; + return std::nullopt; llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); if (!keyMD || keyMD->getString() != "ProfileFormat") { emitWarning(mlirModule.getLoc()) << "expected 'ProfileFormat' key: " << diagMD(tupleEntry->getOperand(0), llvmModule); - return nullptr; + return std::nullopt; } llvm::MDString *valMD = dyn_cast(tupleEntry->getOperand(1)); - auto formatAttr = llvm::StringSwitch(valMD->getString()) - .Case("SampleProfile", "SampleProfile") - .Case("InstrProf", "InstrProf") - .Case("CSInstrProf", "CSInstrProf") - .Default(""); - if (formatAttr.empty()) { + std::optional fmtKind = + symbolizeProfileSummaryFormatKind(valMD->getString()); + if (!fmtKind) { emitWarning(mlirModule.getLoc()) << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, " "but found: " << diagMD(valMD, llvmModule); - return nullptr; + return std::nullopt; } - return StringAttr::get(mlirModule->getContext(), formatAttr); + return fmtKind; }; auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey, @@ -746,8 +744,9 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in // a fixed order: format, total count, etc. SmallVector profileSummary; - StringAttr format = getFormat(mdTuple->getOperand(summayIdx++)); - if (!format) + std::optional format = + getFormat(mdTuple->getOperand(summayIdx++)); + if (!format.has_value()) return nullptr; uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0, @@ -793,7 +792,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, // Build the final profile summary attribute. return ModuleFlagProfileSummaryAttr::get( - mlirModule->getContext(), format, totalCount, maxCount, maxInternalCount, + mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount, maxFunctionCount, numCounts, numFunctions, isPartialProfile, partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index bb730b28b947d..f9ea066a63624 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1807,9 +1807,10 @@ module { // ----- -// expected-error@below {{'ProfileFormat' must be 'SampleProfile', 'InstrProf' or 'CSInstrProf'}} llvm.module_flags [#llvm.mlir.module_flag, + // expected-error@below {{failed to parse ModuleFlagAttr parameter}} >>] // ----- diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir index 148b1eb87fa75..3935a1f5bc621 100644 --- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir @@ -13,7 +13,7 @@ module { #llvm.cgprofile_entry ]>, #llvm.mlir.module_flag // CHECK-SAME: ]>, // CHECK-SAME: #llvm.mlir.module_flag Date: Thu, 1 May 2025 16:47:17 -0700 Subject: [PATCH 06/14] Regex tests and remove auto --- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 2 +- mlir/test/Target/LLVMIR/llvmir.mlir | 30 ++++++++++++++----------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 13cd9229846c9..91066dbc44058 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -607,7 +607,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey, bool optional = false) -> llvm::ConstantAsMetadata * { - auto *tupleEntry = getMDTuple(md); + llvm::MDTuple *tupleEntry = getMDTuple(md); if (!tupleEntry) return nullptr; llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 1b36dc9672f0c..854034f3ec243 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -2894,19 +2894,23 @@ llvm.module_flags [#llvm.mlir.module_flag >>] -// CHECK: !llvm.module.flags = !{!0, !15} - -// CHECK: !0 = !{i32 1, !"ProfileSummary", !1} -// CHECK: !1 = !{!2, !3, !4, !5, !6, !7, !8, !9, !10, !11} -// CHECK: !2 = !{!"ProfileFormat", !"InstrProf"} -// CHECK: !3 = !{!"TotalCount", i64 263646} -// CHECK: !4 = !{!"MaxCount", i64 86427} -// CHECK: !5 = !{!"MaxInternalCount", i64 86427} -// CHECK: !6 = !{!"MaxFunctionCount", i64 4691} -// CHECK: !7 = !{!"NumCounts", i64 3712} -// CHECK: !8 = !{!"NumFunctions", i64 796} -// CHECK: !9 = !{!"IsPartialProfile", i64 0} -// CHECK: !10 = !{!"PartialProfileRatio", double 0.000000e+00} +// CHECK: !llvm.module.flags = !{![[#PSUM:]], {{.*}}} + +// CHECK: ![[#PSUM]] = !{i32 1, !"ProfileSummary", ![[#SUMLIST:]]} +// CHECK: ![[#SUMLIST]] = !{![[#FMT:]], ![[#TC:]], ![[#MC:]], ![[#MIC:]], ![[#MFC:]], ![[#NC:]], ![[#NF:]], ![[#IPP:]], ![[#PPR:]], ![[#DS:]]} +// CHECK: ![[#FMT]] = !{!"ProfileFormat", !"InstrProf"} +// CHECK: ![[#TC]] = !{!"TotalCount", i64 263646} +// CHECK: ![[#MC]] = !{!"MaxCount", i64 86427} +// CHECK: ![[#MIC]] = !{!"MaxInternalCount", i64 86427} +// CHECK: ![[#MFC]] = !{!"MaxFunctionCount", i64 4691} +// CHECK: ![[#NC]] = !{!"NumCounts", i64 3712} +// CHECK: ![[#NF]] = !{!"NumFunctions", i64 796} +// CHECK: ![[#IPP]] = !{!"IsPartialProfile", i64 0} +// CHECK: ![[#PPR]] = !{!"PartialProfileRatio", double 0.000000e+00} +// CHECK: ![[#DS]] = !{!"DetailedSummary", ![[#DETAILED:]]} +// CHECK: ![[#DETAILED]] = !{![[#DS0:]], ![[#DS1:]]} +// CHECK: ![[#DS0:]] = !{i64 10000, i64 86427, i64 1} +// CHECK: ![[#DS1:]] = !{i64 100000, i64 86427, i64 1} // ----- From 0eff0a449515f8bd2f62b410108f8c80cb01f704 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Thu, 1 May 2025 16:59:27 -0700 Subject: [PATCH 07/14] move some lambdas to static local functions --- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 307 +++++++++++++----------- 1 file changed, 162 insertions(+), 145 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 91066dbc44058..edd958821a135 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -554,74 +554,152 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule, return ArrayAttr::get(mlirModule->getContext(), cgProfile); } -static Attribute -convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, - const llvm::Module *llvmModule, - llvm::MDTuple *mdTuple) { - unsigned profileNumEntries = mdTuple->getNumOperands(); - if (profileNumEntries < 8) { +static llvm::MDTuple *getTwoElementMDTuple(ModuleOp mlirModule, + const llvm::Module *llvmModule, + const llvm::MDOperand &md) { + auto *tupleEntry = dyn_cast_or_null(md); + if (!tupleEntry || tupleEntry->getNumOperands() != 2) emitWarning(mlirModule.getLoc()) - << "expected at 8 entries in 'ProfileSummary': " - << diagMD(mdTuple, llvmModule); + << "expected 2-element tuple metadata: " << diagMD(md, llvmModule); + return tupleEntry; +} + +static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple( + ModuleOp mlirModule, const llvm::Module *llvmModule, + const llvm::MDOperand &md, StringRef matchKey, bool optional = false) { + llvm::MDTuple *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, md); + if (!tupleEntry) + return nullptr; + llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); + if (!keyMD || keyMD->getString() != matchKey) { + if (!optional) + emitWarning(mlirModule.getLoc()) + << "expected '" << matchKey << "' key, but found: " + << diagMD(tupleEntry->getOperand(0), llvmModule); return nullptr; } - unsigned summayIdx = 0; + return dyn_cast(tupleEntry->getOperand(1)); +} - auto getMDTuple = [&](const llvm::MDOperand &md) { - auto *tupleEntry = dyn_cast_or_null(md); - if (!tupleEntry || tupleEntry->getNumOperands() != 2) - emitWarning(mlirModule.getLoc()) - << "expected 2-element tuple metadata: " << diagMD(md, llvmModule); - return tupleEntry; - }; +static bool convertInt64FromKeyValueTuple(ModuleOp mlirModule, + const llvm::Module *llvmModule, + const llvm::MDOperand &md, + StringRef matchKey, uint64_t &val) { + auto *valMD = + getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey); + if (!valMD) + return false; - auto getFormat = [&](const llvm::MDOperand &formatMD) - -> std::optional { - auto *tupleEntry = getMDTuple(formatMD); - if (!tupleEntry) - return std::nullopt; + if (auto *cstInt = dyn_cast(valMD->getValue())) { + val = cstInt->getZExtValue(); + return true; + } + + emitWarning(mlirModule.getLoc()) + << "expected integer metadata value for key '" << matchKey + << "': " << diagMD(md, llvmModule); + return false; +} + +static std::optional +convertProfileSummaryFormat(ModuleOp mlirModule, const llvm::Module *llvmModule, + const llvm::MDOperand &formatMD) { + auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, formatMD); + if (!tupleEntry) + return std::nullopt; + + llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); + if (!keyMD || keyMD->getString() != "ProfileFormat") { + emitWarning(mlirModule.getLoc()) + << "expected 'ProfileFormat' key: " + << diagMD(tupleEntry->getOperand(0), llvmModule); + return std::nullopt; + } + + llvm::MDString *valMD = dyn_cast(tupleEntry->getOperand(1)); + std::optional fmtKind = + symbolizeProfileSummaryFormatKind(valMD->getString()); + if (!fmtKind) { + emitWarning(mlirModule.getLoc()) + << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, " + "but found: " + << diagMD(valMD, llvmModule); + return std::nullopt; + } + + return fmtKind; +} + +static bool convertProfileSummaryDetailed( + ModuleOp mlirModule, const llvm::Module *llvmModule, + const llvm::MDOperand &summaryMD, + SmallVectorImpl &detailedSummary) { + auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, summaryMD); + if (!tupleEntry) + return false; + + llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); + if (!keyMD || keyMD->getString() != "DetailedSummary") { + emitWarning(mlirModule.getLoc()) + << "expected 'DetailedSummary' key: " + << diagMD(tupleEntry->getOperand(0), llvmModule); + return false; + } - llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); - if (!keyMD || keyMD->getString() != "ProfileFormat") { + llvm::MDTuple *entriesMD = dyn_cast(tupleEntry->getOperand(1)); + if (!entriesMD) { + emitWarning(mlirModule.getLoc()) + << "expected tuple value for 'DetailedSummary' key: " + << diagMD(tupleEntry->getOperand(1), llvmModule); + return false; + } + + for (auto &&entry : entriesMD->operands()) { + llvm::MDTuple *entryMD = dyn_cast(entry); + if (!entryMD || entryMD->getNumOperands() != 3) { emitWarning(mlirModule.getLoc()) - << "expected 'ProfileFormat' key: " - << diagMD(tupleEntry->getOperand(0), llvmModule); - return std::nullopt; + << "'DetailedSummary' entry expects 3 operands: " + << diagMD(entry, llvmModule); + return false; } - - llvm::MDString *valMD = dyn_cast(tupleEntry->getOperand(1)); - std::optional fmtKind = - symbolizeProfileSummaryFormatKind(valMD->getString()); - if (!fmtKind) { + llvm::ConstantAsMetadata *op0 = + dyn_cast(entryMD->getOperand(0)); + llvm::ConstantAsMetadata *op1 = + dyn_cast(entryMD->getOperand(1)); + llvm::ConstantAsMetadata *op2 = + dyn_cast(entryMD->getOperand(2)); + + if (!op0 || !op1 || !op2) { emitWarning(mlirModule.getLoc()) - << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, " - "but found: " - << diagMD(valMD, llvmModule); - return std::nullopt; + << "expected only integer entries in 'DetailedSummary': " + << diagMD(entry, llvmModule); + return false; } - return fmtKind; - }; - - auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey, - bool optional = - false) -> llvm::ConstantAsMetadata * { - llvm::MDTuple *tupleEntry = getMDTuple(md); - if (!tupleEntry) - return nullptr; - llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); - if (!keyMD || keyMD->getString() != matchKey) { - if (!optional) - emitWarning(mlirModule.getLoc()) - << "expected '" << matchKey << "' key, but found: " - << diagMD(tupleEntry->getOperand(0), llvmModule); - return nullptr; - } + auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get( + mlirModule->getContext(), + cast(op0->getValue())->getZExtValue(), + cast(op1->getValue())->getZExtValue(), + cast(op2->getValue())->getZExtValue()); + detailedSummary.push_back(detaildSummaryEntry); + } + return true; +} - return dyn_cast(tupleEntry->getOperand(1)); - }; +static Attribute +convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, + const llvm::Module *llvmModule, + llvm::MDTuple *mdTuple) { + unsigned profileNumEntries = mdTuple->getNumOperands(); + if (profileNumEntries < 8) { + emitWarning(mlirModule.getLoc()) + << "expected at 8 entries in 'ProfileSummary': " + << diagMD(mdTuple, llvmModule); + return nullptr; + } + unsigned summayIdx = 0; auto checkOptionalPosition = [&](const llvm::MDOperand &md, StringRef matchKey) -> LogicalResult { // Make sure we won't step over the bound of the array of summary entries. @@ -637,31 +715,16 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return success(); }; - auto getInt64Value = [&](const llvm::MDOperand &md, StringRef matchKey, - uint64_t &val) { - auto *valMD = getConstantMD(md, matchKey); - if (!valMD) - return false; - - if (auto *cstInt = dyn_cast(valMD->getValue())) { - val = cstInt->getZExtValue(); - return true; - } - - emitWarning(mlirModule.getLoc()) - << "expected integer metadata value for key '" << matchKey - << "': " << diagMD(md, llvmModule); - return false; - }; - auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey, std::optional &val) -> LogicalResult { - if (!getConstantMD(md, matchKey, /*optional=*/true)) + if (!getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey, + /*optional=*/true)) return success(); if (checkOptionalPosition(md, matchKey).failed()) return failure(); uint64_t tmpVal = 0; - if (!getInt64Value(md, matchKey, tmpVal)) + if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey, + tmpVal)) return failure(); val = tmpVal; return success(); @@ -669,7 +732,8 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey, FloatAttr &attr) -> LogicalResult { - auto *valMD = getConstantMD(md, matchKey, /*optional=*/true); + auto *valMD = getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, + matchKey, /*optional=*/true); if (!valMD) return success(); if (auto *cstFP = dyn_cast(valMD->getValue())) { @@ -685,87 +749,39 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return failure(); }; - auto getSummary = [&](const llvm::MDOperand &summaryMD, - SmallVectorImpl - &detailedSummary) { - auto *tupleEntry = getMDTuple(summaryMD); - if (!tupleEntry) - return false; - - llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); - if (!keyMD || keyMD->getString() != "DetailedSummary") { - emitWarning(mlirModule.getLoc()) - << "expected 'DetailedSummary' key: " - << diagMD(tupleEntry->getOperand(0), llvmModule); - return false; - } - - llvm::MDTuple *entriesMD = - dyn_cast(tupleEntry->getOperand(1)); - if (!entriesMD) { - emitWarning(mlirModule.getLoc()) - << "expected tuple value for 'DetailedSummary' key: " - << diagMD(tupleEntry->getOperand(1), llvmModule); - return false; - } - - for (auto &&entry : entriesMD->operands()) { - llvm::MDTuple *entryMD = dyn_cast(entry); - if (!entryMD || entryMD->getNumOperands() != 3) { - emitWarning(mlirModule.getLoc()) - << "'DetailedSummary' entry expects 3 operands: " - << diagMD(entry, llvmModule); - return false; - } - llvm::ConstantAsMetadata *op0 = - dyn_cast(entryMD->getOperand(0)); - llvm::ConstantAsMetadata *op1 = - dyn_cast(entryMD->getOperand(1)); - llvm::ConstantAsMetadata *op2 = - dyn_cast(entryMD->getOperand(2)); - - if (!op0 || !op1 || !op2) { - emitWarning(mlirModule.getLoc()) - << "expected only integer entries in 'DetailedSummary': " - << diagMD(entry, llvmModule); - return false; - } - - auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get( - mlirModule->getContext(), - cast(op0->getValue())->getZExtValue(), - cast(op1->getValue())->getZExtValue(), - cast(op2->getValue())->getZExtValue()); - detailedSummary.push_back(detaildSummaryEntry); - } - return true; - }; - // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in // a fixed order: format, total count, etc. SmallVector profileSummary; - std::optional format = - getFormat(mdTuple->getOperand(summayIdx++)); + std::optional format = convertProfileSummaryFormat( + mlirModule, llvmModule, mdTuple->getOperand(summayIdx++)); if (!format.has_value()) return nullptr; uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0, maxFunctionCount = 0, numCounts = 0, numFunctions = 0; - if (!getInt64Value(mdTuple->getOperand(summayIdx++), "TotalCount", - totalCount)) + if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, + mdTuple->getOperand(summayIdx++), + "TotalCount", totalCount)) return nullptr; - if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxCount", maxCount)) + if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, + mdTuple->getOperand(summayIdx++), + "MaxCount", maxCount)) return nullptr; - if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxInternalCount", - maxInternalCount)) + if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, + mdTuple->getOperand(summayIdx++), + "MaxInternalCount", maxInternalCount)) return nullptr; - if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxFunctionCount", - maxFunctionCount)) + if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, + mdTuple->getOperand(summayIdx++), + "MaxFunctionCount", maxFunctionCount)) return nullptr; - if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumCounts", numCounts)) + if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, + mdTuple->getOperand(summayIdx++), + "NumCounts", numCounts)) return nullptr; - if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumFunctions", - numFunctions)) + if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, + mdTuple->getOperand(summayIdx++), + "NumFunctions", numFunctions)) return nullptr; // Handle optional keys. @@ -786,15 +802,16 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, summayIdx++; // Handle detailed summary. - SmallVector detailedSummary; - if (!getSummary(mdTuple->getOperand(summayIdx), detailedSummary)) + SmallVector detailed; + if (!convertProfileSummaryDetailed(mlirModule, llvmModule, + mdTuple->getOperand(summayIdx), detailed)) return nullptr; // Build the final profile summary attribute. return ModuleFlagProfileSummaryAttr::get( mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount, maxFunctionCount, numCounts, numFunctions, isPartialProfile, - partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary); + partialProfileRatio ? partialProfileRatio : nullptr, detailed); } /// Invoke specific handlers for each known module flag value, returns nullptr From 6f203a2bf89773f7bd132c300e9d91b1a14ad718 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Fri, 2 May 2025 15:14:05 -0700 Subject: [PATCH 08/14] Use FailureOr --- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 105 +++++++++++++----------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index edd958821a135..5d4190022e5b3 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -570,7 +570,7 @@ static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple( llvm::MDTuple *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, md); if (!tupleEntry) return nullptr; - llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); + auto *keyMD = dyn_cast(tupleEntry->getOperand(0)); if (!keyMD || keyMD->getString() != matchKey) { if (!optional) emitWarning(mlirModule.getLoc()) @@ -582,24 +582,22 @@ static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple( return dyn_cast(tupleEntry->getOperand(1)); } -static bool convertInt64FromKeyValueTuple(ModuleOp mlirModule, - const llvm::Module *llvmModule, - const llvm::MDOperand &md, - StringRef matchKey, uint64_t &val) { +static FailureOr +convertInt64FromKeyValueTuple(ModuleOp mlirModule, + const llvm::Module *llvmModule, + const llvm::MDOperand &md, StringRef matchKey) { auto *valMD = getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey); if (!valMD) - return false; + return failure(); - if (auto *cstInt = dyn_cast(valMD->getValue())) { - val = cstInt->getZExtValue(); - return true; - } + if (auto *cstInt = dyn_cast(valMD->getValue())) + return cstInt->getZExtValue(); emitWarning(mlirModule.getLoc()) << "expected integer metadata value for key '" << matchKey << "': " << diagMD(md, llvmModule); - return false; + return failure(); } static std::optional @@ -631,20 +629,20 @@ convertProfileSummaryFormat(ModuleOp mlirModule, const llvm::Module *llvmModule, return fmtKind; } -static bool convertProfileSummaryDetailed( - ModuleOp mlirModule, const llvm::Module *llvmModule, - const llvm::MDOperand &summaryMD, - SmallVectorImpl &detailedSummary) { +static FailureOr> +convertProfileSummaryDetailed(ModuleOp mlirModule, + const llvm::Module *llvmModule, + const llvm::MDOperand &summaryMD) { auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, summaryMD); if (!tupleEntry) - return false; + return failure(); llvm::MDString *keyMD = dyn_cast(tupleEntry->getOperand(0)); if (!keyMD || keyMD->getString() != "DetailedSummary") { emitWarning(mlirModule.getLoc()) << "expected 'DetailedSummary' key: " << diagMD(tupleEntry->getOperand(0), llvmModule); - return false; + return failure(); } llvm::MDTuple *entriesMD = dyn_cast(tupleEntry->getOperand(1)); @@ -652,16 +650,17 @@ static bool convertProfileSummaryDetailed( emitWarning(mlirModule.getLoc()) << "expected tuple value for 'DetailedSummary' key: " << diagMD(tupleEntry->getOperand(1), llvmModule); - return false; + return failure(); } + SmallVector detailedSummary; for (auto &&entry : entriesMD->operands()) { llvm::MDTuple *entryMD = dyn_cast(entry); if (!entryMD || entryMD->getNumOperands() != 3) { emitWarning(mlirModule.getLoc()) << "'DetailedSummary' entry expects 3 operands: " << diagMD(entry, llvmModule); - return false; + return failure(); } llvm::ConstantAsMetadata *op0 = dyn_cast(entryMD->getOperand(0)); @@ -674,7 +673,7 @@ static bool convertProfileSummaryDetailed( emitWarning(mlirModule.getLoc()) << "expected only integer entries in 'DetailedSummary': " << diagMD(entry, llvmModule); - return false; + return failure(); } auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get( @@ -684,7 +683,7 @@ static bool convertProfileSummaryDetailed( cast(op2->getValue())->getZExtValue()); detailedSummary.push_back(detaildSummaryEntry); } - return true; + return detailedSummary; } static Attribute @@ -722,9 +721,9 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return success(); if (checkOptionalPosition(md, matchKey).failed()) return failure(); - uint64_t tmpVal = 0; - if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey, - tmpVal)) + FailureOr tmpVal = + convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey); + if (failed(tmpVal)) return failure(); val = tmpVal; return success(); @@ -757,31 +756,36 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, if (!format.has_value()) return nullptr; - uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0, - maxFunctionCount = 0, numCounts = 0, numFunctions = 0; - if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, - mdTuple->getOperand(summayIdx++), - "TotalCount", totalCount)) + FailureOr totalCount = convertInt64FromKeyValueTuple( + mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "TotalCount"); + if (failed(totalCount)) return nullptr; - if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, - mdTuple->getOperand(summayIdx++), - "MaxCount", maxCount)) + + FailureOr maxCount = convertInt64FromKeyValueTuple( + mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "MaxCount"); + if (failed(maxCount)) return nullptr; - if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, - mdTuple->getOperand(summayIdx++), - "MaxInternalCount", maxInternalCount)) + + FailureOr maxInternalCount = convertInt64FromKeyValueTuple( + mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), + "MaxInternalCount"); + if (failed(maxInternalCount)) return nullptr; - if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, - mdTuple->getOperand(summayIdx++), - "MaxFunctionCount", maxFunctionCount)) + + FailureOr maxFunctionCount = convertInt64FromKeyValueTuple( + mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), + "MaxFunctionCount"); + if (failed(maxFunctionCount)) return nullptr; - if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, - mdTuple->getOperand(summayIdx++), - "NumCounts", numCounts)) + + FailureOr numCounts = convertInt64FromKeyValueTuple( + mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "NumCounts"); + if (failed(numCounts)) return nullptr; - if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, - mdTuple->getOperand(summayIdx++), - "NumFunctions", numFunctions)) + + FailureOr numFunctions = convertInt64FromKeyValueTuple( + mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "NumFunctions"); + if (failed(numFunctions)) return nullptr; // Handle optional keys. @@ -802,16 +806,17 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, summayIdx++; // Handle detailed summary. - SmallVector detailed; - if (!convertProfileSummaryDetailed(mlirModule, llvmModule, - mdTuple->getOperand(summayIdx), detailed)) + FailureOr> detailed = + convertProfileSummaryDetailed(mlirModule, llvmModule, + mdTuple->getOperand(summayIdx)); + if (failed(detailed)) return nullptr; // Build the final profile summary attribute. return ModuleFlagProfileSummaryAttr::get( - mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount, - maxFunctionCount, numCounts, numFunctions, isPartialProfile, - partialProfileRatio ? partialProfileRatio : nullptr, detailed); + mlirModule->getContext(), *format, *totalCount, *maxCount, + *maxInternalCount, *maxFunctionCount, *numCounts, *numFunctions, + isPartialProfile, partialProfileRatio, *detailed); } /// Invoke specific handlers for each known module flag value, returns nullptr From 181fb6bf741d03ed2fc3b01e1155e187a2e986b6 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Fri, 2 May 2025 15:52:20 -0700 Subject: [PATCH 09/14] more nits and cleanup --- .../Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 5 +++-- mlir/test/Dialect/LLVMIR/module-roundtrip.mlir | 2 -- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 1e517ceb827ac..e57aecd13916f 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -347,7 +347,8 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr( } SmallVector detailedEntries; - for (auto detailedEntry : summaryAttr.getDetailedSummary()) { + for (ModuleFlagProfileSummaryDetailedAttr detailedEntry : + summaryAttr.getDetailedSummary()) { SmallVector tupleNodes{ mdb.createConstant(llvm::ConstantInt::get( llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())), @@ -385,7 +386,7 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder, arrayAttr, builder, moduleTranslation); }) - .Case([&](auto summaryAttr) { + .Case([&](ModuleFlagProfileSummaryAttr summaryAttr) { return convertModuleFlagProfileSummaryAttr( flagAttr.getKey().getValue(), summaryAttr, builder, moduleTranslation); diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir index 3935a1f5bc621..85abd57df53c8 100644 --- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir @@ -46,5 +46,3 @@ module { // CHECK-SAME: , // CHECK-SAME: // CHECK-SAME: >>] - -llvm.module_flags [] From 9052f26607a39adad887a482db1b006f9e0c5b8f Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Fri, 2 May 2025 16:00:01 -0700 Subject: [PATCH 10/14] add doc to helper functions --- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 5d4190022e5b3..383c3963101d7 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -554,6 +554,8 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule, return ArrayAttr::get(mlirModule->getContext(), cgProfile); } +/// Extract a two element `MDTuple` from a `MDOperand`. Emit a warning in case +/// something else is found. static llvm::MDTuple *getTwoElementMDTuple(ModuleOp mlirModule, const llvm::Module *llvmModule, const llvm::MDOperand &md) { @@ -564,6 +566,9 @@ static llvm::MDTuple *getTwoElementMDTuple(ModuleOp mlirModule, return tupleEntry; } +/// Extract a constant metadata value from a two element tuple (). +/// Return nullptr if requirements are not met. A warning is emitted if the +/// `matchKey` is different from the tuple's key. static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple( ModuleOp mlirModule, const llvm::Module *llvmModule, const llvm::MDOperand &md, StringRef matchKey, bool optional = false) { @@ -582,6 +587,9 @@ static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple( return dyn_cast(tupleEntry->getOperand(1)); } +/// Extract an integer value from a two element tuple (). +/// Fail if requirements are not met. A warning is emitted if the +/// found value isn't a LLVM constant integer. static FailureOr convertInt64FromKeyValueTuple(ModuleOp mlirModule, const llvm::Module *llvmModule, From c5056ac767364f98dfc07fc74149c325ac1ae974 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Fri, 2 May 2025 16:04:17 -0700 Subject: [PATCH 11/14] hoist type creation --- .../LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index e57aecd13916f..82bdc51145d1c 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -347,15 +347,16 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr( } SmallVector detailedEntries; + llvm::Type *llvmInt64Type = llvm::Type::getInt64Ty(context); for (ModuleFlagProfileSummaryDetailedAttr detailedEntry : summaryAttr.getDetailedSummary()) { SmallVector tupleNodes{ + mdb.createConstant( + llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getCutOff())), + mdb.createConstant( + llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getMinCount())), mdb.createConstant(llvm::ConstantInt::get( - llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())), - mdb.createConstant(llvm::ConstantInt::get( - llvm::Type::getInt64Ty(context), detailedEntry.getMinCount())), - mdb.createConstant(llvm::ConstantInt::get( - llvm::Type::getInt64Ty(context), detailedEntry.getNumCounts()))}; + llvmInt64Type, detailedEntry.getNumCounts()))}; detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes)); } SmallVector detailedSummary{ From b1096b77661aa2477e528bdf9ac970409e69ec8a Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Fri, 2 May 2025 16:13:43 -0700 Subject: [PATCH 12/14] Use FailureOr for getOptIntValue --- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 383c3963101d7..4a0b4c29f9dc4 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -41,6 +41,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/Support/ModRef.h" +#include using namespace mlir; using namespace mlir::LLVM; @@ -722,11 +723,13 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return success(); }; - auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey, - std::optional &val) -> LogicalResult { + auto getOptIntValue = + [&](const llvm::MDOperand &md, + StringRef matchKey) -> FailureOr> { + std::optional val = std::nullopt; if (!getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey, /*optional=*/true)) - return success(); + return val; if (checkOptionalPosition(md, matchKey).failed()) return failure(); FailureOr tmpVal = @@ -734,7 +737,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, if (failed(tmpVal)) return failure(); val = tmpVal; - return success(); + return val; }; auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey, @@ -797,12 +800,11 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return nullptr; // Handle optional keys. - std::optional isPartialProfile; - if (getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile", - isPartialProfile) - .failed()) + FailureOr> isPartialProfile = + getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile"); + if (failed(isPartialProfile)) return nullptr; - if (isPartialProfile.has_value()) + if (isPartialProfile->has_value()) summayIdx++; FloatAttr partialProfileRatio; @@ -824,7 +826,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return ModuleFlagProfileSummaryAttr::get( mlirModule->getContext(), *format, *totalCount, *maxCount, *maxInternalCount, *maxFunctionCount, *numCounts, *numFunctions, - isPartialProfile, partialProfileRatio, *detailed); + *isPartialProfile, partialProfileRatio, *detailed); } /// Invoke specific handlers for each known module flag value, returns nullptr From 6f0bd126cdbad229356c4c39cae2c289a9eaffe1 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Fri, 2 May 2025 16:26:31 -0700 Subject: [PATCH 13/14] Use FailureOr for getOptDoubleValue --- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 4a0b4c29f9dc4..8ca5576900772 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -740,18 +740,17 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return val; }; - auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey, - FloatAttr &attr) -> LogicalResult { + auto getOptDoubleValue = [&](const llvm::MDOperand &md, + StringRef matchKey) -> FailureOr { auto *valMD = getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey, /*optional=*/true); if (!valMD) - return success(); + return FloatAttr{}; if (auto *cstFP = dyn_cast(valMD->getValue())) { if (checkOptionalPosition(md, matchKey).failed()) return failure(); - attr = FloatAttr::get(Float64Type::get(mlirModule.getContext()), + return FloatAttr::get(Float64Type::get(mlirModule.getContext()), cstFP->getValueAPF()); - return success(); } emitWarning(mlirModule.getLoc()) << "expected double metadata value for key '" << matchKey @@ -807,12 +806,11 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, if (isPartialProfile->has_value()) summayIdx++; - FloatAttr partialProfileRatio; - if (getOptDoubleValue(mdTuple->getOperand(summayIdx), "PartialProfileRatio", - partialProfileRatio) - .failed()) + FailureOr partialProfileRatio = + getOptDoubleValue(mdTuple->getOperand(summayIdx), "PartialProfileRatio"); + if (failed(partialProfileRatio)) return nullptr; - if (partialProfileRatio) + if (*partialProfileRatio) summayIdx++; // Handle detailed summary. @@ -826,7 +824,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, return ModuleFlagProfileSummaryAttr::get( mlirModule->getContext(), *format, *totalCount, *maxCount, *maxInternalCount, *maxFunctionCount, *numCounts, *numFunctions, - *isPartialProfile, partialProfileRatio, *detailed); + *isPartialProfile, *partialProfileRatio, *detailed); } /// Invoke specific handlers for each known module flag value, returns nullptr From a86a3b7fb689f67d555afc5ea665afb5249c2644 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Mon, 5 May 2025 11:36:26 -0700 Subject: [PATCH 14/14] address last round of reviews --- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 1 - mlir/lib/Target/LLVMIR/ModuleImport.cpp | 19 +++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 82bdc51145d1c..4ea313019f34d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -300,7 +300,6 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr, } return llvm::MDTuple::getDistinct(context, nodes); } - return nullptr; } diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 8ca5576900772..6f56a17ecd4e3 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -595,7 +595,7 @@ static FailureOr convertInt64FromKeyValueTuple(ModuleOp mlirModule, const llvm::Module *llvmModule, const llvm::MDOperand &md, StringRef matchKey) { - auto *valMD = + llvm::ConstantAsMetadata *valMD = getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey); if (!valMD) return failure(); @@ -671,13 +671,10 @@ convertProfileSummaryDetailed(ModuleOp mlirModule, << diagMD(entry, llvmModule); return failure(); } - llvm::ConstantAsMetadata *op0 = - dyn_cast(entryMD->getOperand(0)); - llvm::ConstantAsMetadata *op1 = - dyn_cast(entryMD->getOperand(1)); - llvm::ConstantAsMetadata *op2 = - dyn_cast(entryMD->getOperand(2)); + auto *op0 = dyn_cast(entryMD->getOperand(0)); + auto *op1 = dyn_cast(entryMD->getOperand(1)); + auto *op2 = dyn_cast(entryMD->getOperand(2)); if (!op0 || !op1 || !op2) { emitWarning(mlirModule.getLoc()) << "expected only integer entries in 'DetailedSummary': " @@ -726,17 +723,15 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule, auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey) -> FailureOr> { - std::optional val = std::nullopt; if (!getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey, /*optional=*/true)) - return val; + return FailureOr>(std::nullopt); if (checkOptionalPosition(md, matchKey).failed()) return failure(); - FailureOr tmpVal = + FailureOr val = convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey); - if (failed(tmpVal)) + if (failed(val)) return failure(); - val = tmpVal; return val; };