Skip to content

Conversation

@bcardosolopes
Copy link
Member

@bcardosolopes bcardosolopes commented May 1, 2025

Add one more of these module flags.

Unlike "CG Profile", LLVM proper does not verify the content of the metadata, but returns a nullptr in case it's ill-formed (it's up to the user to take action). This prompted me to implement warning checks, preventing the importer to consume broken data.

Unlike "CG Profile", LLVM proper does not verify the content of the metadata,
but returns an empty one in case it's ill-formed. To that intent the importer
here does a significant amount of checks to avoid consuming bad content.
@llvmbot
Copy link
Member

llvmbot commented May 1, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Bruno Cardoso Lopes (bcardosolopes)

Changes

Add one more of these module flags.

Unlike "CG Profile", LLVM proper does not verify the content of the metadata, but returns a nullptr in case it's ill-formed (it's up to the user to take action). This prompted me to implement warning checks and prevent consume broken data.


Patch is 32.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138070.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td (+48)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td (+3)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp (+13)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp (+68)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+254-5)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+7)
  • (modified) mlir/test/Dialect/LLVMIR/module-roundtrip.mlir (+24-2)
  • (modified) mlir/test/Target/LLVMIR/Import/import-failure.ll (+124)
  • (modified) mlir/test/Target/LLVMIR/Import/module-flags.ll (+59-8)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+27)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 7d6d38ecad897..5eb66745db829 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1378,6 +1378,54 @@ def ModuleFlagCGProfileEntryAttr
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def ModuleFlagProfileSummaryDetailedAttr
+    : LLVM_Attr<"ModuleFlagProfileSummaryDetailed", "profile_summary_detailed"> {
+  let summary = "ProfileSummary detailed information";
+  let description = [{
+    Contains detailed information pertinent to "ProfileSummary" attribute.
+    A `#llvm.profile_summary` may contain several of it.
+    ```mlir
+    llvm.module_flags [ ...
+        detailed_summary = [
+        #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+        #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+    ```
+  }];
+  let parameters = (ins "uint32_t":$cut_off,
+                        "uint64_t":$min_count,
+                        "uint32_t":$num_counts);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def ModuleFlagProfileSummaryAttr
+    : LLVM_Attr<"ModuleFlagProfileSummary", "profile_summary"> {
+  let summary = "ProfileSummary module flag";
+  let description = [{
+    Describes ProfileSummary gathered data in a module. Example:
+    ```mlir
+    llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
+      #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
+        max_internal_count = 86427, max_function_count = 4691,
+        num_counts = 3712, num_functions = 796,
+        is_partial_profile = 0 : i64,
+        partial_profile_ratio = 0.000000e+00 : f64,
+        detailed_summary = [
+        #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+        #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+    ]>>]
+    ```
+  }];
+  let parameters = (
+    ins "StringAttr":$format, "uint64_t":$total_count, "uint64_t":$max_count,
+        "uint64_t":$max_internal_count, "uint64_t":$max_function_count,
+        "uint64_t":$num_counts, "uint64_t":$num_functions,
+        OptionalParameter<"IntegerAttr">:$is_partial_profile,
+        OptionalParameter<"FloatAttr">:$partial_profile_ratio,
+        "ArrayAttr":$detailed_summary);
+
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // LLVM_DependentLibrariesAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
index 9f9d075a3eebf..b5ea8fc5da500 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
@@ -92,6 +92,9 @@ def LLVM_Dialect : Dialect {
     static StringRef getModuleFlagKeyCGProfileName() {
       return "CG Profile";
     }
+    static StringRef getModuleFlagKeyProfileSummaryName() {
+      return "ProfileSummary";
+    }
 
     /// Returns `true` if the given type is compatible with the LLVM dialect.
     static bool isCompatibleType(Type);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index ffde597ac83c1..ef689e3721d91 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -390,6 +390,19 @@ ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return success();
   }
 
+  if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) {
+    if (auto summaryAttr = dyn_cast<ModuleFlagProfileSummaryAttr>(value)) {
+      StringRef fmt = summaryAttr.getFormat().getValue();
+      if (fmt != "SampleProfile" && fmt != "InstrProf" && fmt != "CSInstrProf")
+        return emitError() << "'ProfileFormat' must be 'SampleProfile', "
+                              "'InstrProf' or 'CSInstrProf'";
+    } else {
+      return emitError() << "'ProfileSummary' key expects a "
+                            "'#llvm.profile_summary' attribute";
+    }
+    return success();
+  }
+
   if (isa<IntegerAttr, StringAttr>(value))
     return success();
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 35dcde2a33d41..260d61f97fce5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -300,9 +300,72 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
     }
     return llvm::MDTuple::getDistinct(context, nodes);
   }
+
   return nullptr;
 }
 
+static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
+    StringRef key, ModuleFlagProfileSummaryAttr summaryAttr,
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::LLVMContext &context = builder.getContext();
+  llvm::MDBuilder mdb(context);
+  SmallVector<llvm::Metadata *> summaryNodes;
+
+  auto getIntTuple = [&](StringRef key, uint64_t val) -> llvm::MDTuple * {
+    SmallVector<llvm::Metadata *> tupleNodes{
+        mdb.createString(key), mdb.createConstant(llvm::ConstantInt::get(
+                                   llvm::Type::getInt64Ty(context), val))};
+    return llvm::MDTuple::get(context, tupleNodes);
+  };
+
+  SmallVector<llvm::Metadata *> fmtNode{
+      mdb.createString("ProfileFormat"),
+      mdb.createString(summaryAttr.getFormat().getValue())};
+
+  SmallVector<llvm::Metadata *> vals = {
+      llvm::MDTuple::get(context, fmtNode),
+      getIntTuple("TotalCount", summaryAttr.getTotalCount()),
+      getIntTuple("MaxCount", summaryAttr.getMaxCount()),
+      getIntTuple("MaxInternalCount", summaryAttr.getMaxInternalCount()),
+      getIntTuple("MaxFunctionCount", summaryAttr.getMaxFunctionCount()),
+      getIntTuple("NumCounts", summaryAttr.getNumCounts()),
+      getIntTuple("NumFunctions", summaryAttr.getNumFunctions()),
+  };
+
+  if (summaryAttr.getIsPartialProfile())
+    vals.push_back(getIntTuple("IsPartialProfile",
+                               summaryAttr.getIsPartialProfile().getUInt()));
+
+  if (summaryAttr.getPartialProfileRatio()) {
+    SmallVector<llvm::Metadata *> tupleNodes{
+        mdb.createString("PartialProfileRatio"),
+        mdb.createConstant(llvm::ConstantFP::get(
+            llvm::Type::getDoubleTy(context),
+            summaryAttr.getPartialProfileRatio().getValue()))};
+    vals.push_back(llvm::MDTuple::get(context, tupleNodes));
+  }
+
+  SmallVector<llvm::Metadata *> detailedEntries;
+  for (auto detailedEntry :
+       summaryAttr.getDetailedSummary()
+           .getAsRange<ModuleFlagProfileSummaryDetailedAttr>()) {
+    SmallVector<llvm::Metadata *> tupleNodes{
+        mdb.createConstant(llvm::ConstantInt::get(
+            llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())),
+        mdb.createConstant(llvm::ConstantInt::get(
+            llvm::Type::getInt64Ty(context), detailedEntry.getMinCount())),
+        mdb.createConstant(llvm::ConstantInt::get(
+            llvm::Type::getInt64Ty(context), detailedEntry.getNumCounts()))};
+    detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes));
+  }
+  SmallVector<llvm::Metadata *> detailedSummary{
+      mdb.createString("DetailedSummary"),
+      llvm::MDTuple::get(context, detailedEntries)};
+  vals.push_back(llvm::MDTuple::get(context, detailedSummary));
+
+  return llvm::MDNode::get(context, vals);
+}
+
 static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
                                  LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
@@ -323,6 +386,11 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
                                             arrayAttr, builder,
                                             moduleTranslation);
             })
+            .Case<ModuleFlagProfileSummaryAttr>([&](auto summaryAttr) {
+              return convertModuleFlagProfileSummaryAttr(
+                  flagAttr.getKey().getValue(), summaryAttr, builder,
+                  moduleTranslation);
+            })
             .Default([](auto) { return nullptr; });
 
     assert(valueMetadata && "expected valid metadata");
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 0b77a3d23d392..fff2ae4a65f2d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -554,13 +554,262 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
   return ArrayAttr::get(mlirModule->getContext(), cgProfile);
 }
 
+static Attribute
+convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
+                                     const llvm::Module *llvmModule,
+                                     llvm::MDTuple *mdTuple) {
+  unsigned profileNumEntries = mdTuple->getNumOperands();
+  if (profileNumEntries < 8) {
+    emitWarning(mlirModule.getLoc())
+        << "expected at 8 entries in 'ProfileSummary': "
+        << diagMD(mdTuple, llvmModule);
+    return nullptr;
+  }
+
+  unsigned summayIdx = 0;
+
+  auto getMDTuple = [&](const llvm::MDOperand &md) {
+    auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
+    if (!tupleEntry || tupleEntry->getNumOperands() != 2)
+      emitWarning(mlirModule.getLoc())
+          << "expected 2-element tuple metadata: " << diagMD(md, llvmModule);
+    return tupleEntry;
+  };
+
+  auto getFormat = [&](const llvm::MDOperand &formatMD) -> StringAttr {
+    auto *tupleEntry = getMDTuple(formatMD);
+    if (!tupleEntry)
+      return nullptr;
+
+    llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+    if (!keyMD || keyMD->getString() != "ProfileFormat") {
+      emitWarning(mlirModule.getLoc())
+          << "expected 'ProfileFormat' key: "
+          << diagMD(tupleEntry->getOperand(0), llvmModule);
+      return nullptr;
+    }
+
+    llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(1));
+    auto formatAttr = llvm::StringSwitch<std::string>(valMD->getString())
+                          .Case("SampleProfile", "SampleProfile")
+                          .Case("InstrProf", "InstrProf")
+                          .Case("CSInstrProf", "CSInstrProf")
+                          .Default("");
+    if (formatAttr.empty()) {
+      emitWarning(mlirModule.getLoc())
+          << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
+             "but found: "
+          << diagMD(valMD, llvmModule);
+      return nullptr;
+    }
+
+    return StringAttr::get(mlirModule->getContext(), formatAttr);
+  };
+
+  auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey,
+                           bool optional =
+                               false) -> llvm::ConstantAsMetadata * {
+    auto *tupleEntry = getMDTuple(md);
+    if (!tupleEntry)
+      return nullptr;
+    llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+    if (!keyMD || keyMD->getString() != matchKey) {
+      if (!optional)
+        emitWarning(mlirModule.getLoc())
+            << "expected '" << matchKey << "' key, but found: "
+            << diagMD(tupleEntry->getOperand(0), llvmModule);
+      return nullptr;
+    }
+
+    return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
+  };
+
+  auto checkOptionalPosition = [&](const llvm::MDOperand &md,
+                                   StringRef matchKey) -> LogicalResult {
+    // Make sure we won't step over the bound of the array of summary entries.
+    // Since (non-optional) DetailedSummary always comes last, the next entry in
+    // the tuple operand array must exist.
+    if (summayIdx + 1 >= profileNumEntries) {
+      emitWarning(mlirModule.getLoc())
+          << "the last summary entry is '" << matchKey
+          << "', expected 'DetailedSummary': " << diagMD(md, llvmModule);
+      return failure();
+    }
+
+    return success();
+  };
+
+  auto getInt64Value = [&](const llvm::MDOperand &md, StringRef matchKey,
+                           uint64_t &val) {
+    auto *valMD = getConstantMD(md, matchKey);
+    if (!valMD)
+      return false;
+
+    if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue())) {
+      val = cstInt->getZExtValue();
+      return true;
+    }
+
+    emitWarning(mlirModule.getLoc())
+        << "expected integer metadata value for key '" << matchKey
+        << "': " << diagMD(md, llvmModule);
+    return false;
+  };
+
+  auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey,
+                            IntegerAttr &attr) -> LogicalResult {
+    if (!getConstantMD(md, matchKey, /*optional=*/true))
+      return success();
+    if (checkOptionalPosition(md, matchKey).failed())
+      return failure();
+    uint64_t val = 0;
+    if (!getInt64Value(md, matchKey, val))
+      return failure();
+    attr =
+        IntegerAttr::get(IntegerType::get(mlirModule->getContext(), 64), val);
+    return success();
+  };
+
+  auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey,
+                               FloatAttr &attr) -> LogicalResult {
+    auto *valMD = getConstantMD(md, matchKey, /*optional=*/true);
+    if (!valMD)
+      return success();
+    if (auto *cstFP = dyn_cast<llvm::ConstantFP>(valMD->getValue())) {
+      if (checkOptionalPosition(md, matchKey).failed())
+        return failure();
+      attr = FloatAttr::get(Float64Type::get(mlirModule.getContext()),
+                            cstFP->getValueAPF());
+      return success();
+    }
+    emitWarning(mlirModule.getLoc())
+        << "expected double metadata value for key '" << matchKey
+        << "': " << diagMD(md, llvmModule);
+    return failure();
+  };
+
+  auto getSummary = [&](const llvm::MDOperand &summaryMD) -> ArrayAttr {
+    auto *tupleEntry = getMDTuple(summaryMD);
+    if (!tupleEntry)
+      return nullptr;
+
+    llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+    if (!keyMD || keyMD->getString() != "DetailedSummary") {
+      emitWarning(mlirModule.getLoc())
+          << "expected 'DetailedSummary' key: "
+          << diagMD(tupleEntry->getOperand(0), llvmModule);
+      return nullptr;
+    }
+
+    llvm::MDTuple *entriesMD =
+        dyn_cast<llvm::MDTuple>(tupleEntry->getOperand(1));
+    if (!entriesMD) {
+      emitWarning(mlirModule.getLoc())
+          << "expected tuple value for 'DetailedSummary' key: "
+          << diagMD(tupleEntry->getOperand(1), llvmModule);
+      return nullptr;
+    }
+
+    SmallVector<Attribute> detailedSummary;
+    for (auto &&entry : entriesMD->operands()) {
+      llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
+      if (!entryMD || entryMD->getNumOperands() != 3) {
+        emitWarning(mlirModule.getLoc())
+            << "'DetailedSummary' entry expects 3 operands: "
+            << diagMD(entry, llvmModule);
+        return nullptr;
+      }
+      llvm::ConstantAsMetadata *op0 =
+          dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(0));
+      llvm::ConstantAsMetadata *op1 =
+          dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(1));
+      llvm::ConstantAsMetadata *op2 =
+          dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(2));
+
+      if (!op0 || !op1 || !op2) {
+        emitWarning(mlirModule.getLoc())
+            << "expected only integer entries in 'DetailedSummary': "
+            << diagMD(entry, llvmModule);
+        return nullptr;
+      }
+
+      auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get(
+          mlirModule->getContext(),
+          cast<llvm::ConstantInt>(op0->getValue())->getZExtValue(),
+          cast<llvm::ConstantInt>(op1->getValue())->getZExtValue(),
+          cast<llvm::ConstantInt>(op2->getValue())->getZExtValue());
+      detailedSummary.push_back(detaildSummaryEntry);
+    }
+    return ArrayAttr::get(mlirModule->getContext(), detailedSummary);
+  };
+
+  // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
+  // a fixed order: format, total count, etc.
+  SmallVector<Attribute> profileSummary;
+  StringAttr format = getFormat(mdTuple->getOperand(summayIdx++));
+  if (!format)
+    return nullptr;
+
+  uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
+           maxFunctionCount = 0, numCounts = 0, numFunctions = 0;
+  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "TotalCount",
+                     totalCount))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxCount", maxCount))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxInternalCount",
+                     maxInternalCount))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxFunctionCount",
+                     maxFunctionCount))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumCounts", numCounts))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumFunctions",
+                     numFunctions))
+    return nullptr;
+
+  // Handle optional keys.
+  IntegerAttr isPartialProfile;
+  if (getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile",
+                     isPartialProfile)
+          .failed())
+    return nullptr;
+  if (isPartialProfile)
+    summayIdx++;
+
+  FloatAttr partialProfileRatio;
+  if (getOptDoubleValue(mdTuple->getOperand(summayIdx), "PartialProfileRatio",
+                        partialProfileRatio)
+          .failed())
+    return nullptr;
+  if (partialProfileRatio)
+    summayIdx++;
+
+  // Handle detailed summary.
+  ArrayAttr detailedSummary = getSummary(mdTuple->getOperand(summayIdx));
+  if (!detailedSummary)
+    return nullptr;
+
+  // Build the final profile summary attribute.
+  return ModuleFlagProfileSummaryAttr::get(
+      mlirModule->getContext(), format, totalCount, maxCount, maxInternalCount,
+      maxFunctionCount, numCounts, numFunctions,
+      isPartialProfile ? isPartialProfile : nullptr,
+      partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary);
+}
+
 /// Invoke specific handlers for each known module flag value, returns nullptr
 /// if the key is unknown or unimplemented.
-static Attribute convertModuleFlagValueFromMDTuple(ModuleOp mlirModule,
-                                                   StringRef key,
-                                                   llvm::MDTuple *mdTuple) {
+static Attribute
+convertModuleFlagValueFromMDTuple(ModuleOp mlirModule,
+                                  const llvm::Module *llvmModule, StringRef key,
+                                  llvm::MDTuple *mdTuple) {
   if (key == LLVMDialect::getModuleFlagKeyCGProfileName())
     return convertCGProfileModuleFlagValue(mlirModule, mdTuple);
+  if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName())
+    return convertProfileSummaryModuleFlagValue(mlirModule, llvmModule,
+                                                mdTuple);
   return nullptr;
 }
 
@@ -576,8 +825,8 @@ LogicalResult ModuleImport::convertModuleFlagsMetadata() {
     } else if (auto *mdString = dyn_cast<llvm::MDString>(val)) {
       valAttr = builder.getStringAttr(mdString->getString());
     } else if (auto *mdTuple = dyn_cast<llvm::MDTuple>(val)) {
-      valAttr = convertModuleFlagValueFromMDTuple(mlirModule, key->getString(),
-                                                  mdTuple);
+      valAttr = convertModuleFlagValueFromMDTuple(mlirModule, llvmModule.get(),
+                                                  key->getString(), mdTuple);
     }
 
     if (!valAttr) {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 5dea94026b248..84c0d40c8b346 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1800,6 +1800,13 @@ module {
 
 // -----
 
+module {
+  // expected-error@below {{'ProfileSummary' key expects a '#llvm.profile_summary' attribute}}
+  llvm.module_flags [#llvm.mlir.module_flag<append, "ProfileSummary", 3 : i64>]
+}
+
+// -----
+
 llvm.func @t0...
[truncated]

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the verification effort.

I did a pass and added comments. Not sure about introducing the enum for SampleProfile/InstrProf etc the but I imagine it could make sense.

@bcardosolopes
Copy link
Member Author

Applied all feedback, thanks!

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo nits.

I commented in some places that return FailureOr would be nicer. I think returning output parameters is generally preferable. However, I am ok with the current solution as well.

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropped a bunch of nits

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing my comments ❤️ . LGTM % some final nits.

@bcardosolopes bcardosolopes merged commit 28934fe into llvm:main May 5, 2025
11 checks passed
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
Add one more of these module flags. 

Unlike "CG Profile", LLVM proper does not verify the content of the
metadata, but returns a nullptr in case it's ill-formed (it's up to the
user to take action). This prompted me to implement warning checks,
preventing the importer to consume broken data.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants