Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,54 @@ def ModuleFlagCGProfileEntryAttr
let assemblyFormat = "`<` struct(params) `>`";
}

def ModuleFlagProfileSummaryDetailedAttr
: LLVM_Attr<"ModuleFlagProfileSummaryDetailed", "profile_summary_detailed"> {
let summary = "ProfileSummary detailed information";
let description = [{
Contains detailed information pertinent to "ProfileSummary" attribute.
A `#llvm.profile_summary` may contain several of it.
```mlir
llvm.module_flags [ ...
detailed_summary = [
#llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
#llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
```
}];
let parameters = (ins "uint32_t":$cut_off,
"uint64_t":$min_count,
"uint32_t":$num_counts);
let assemblyFormat = "`<` struct(params) `>`";
}

def ModuleFlagProfileSummaryAttr
: LLVM_Attr<"ModuleFlagProfileSummary", "profile_summary"> {
let summary = "ProfileSummary module flag";
let description = [{
Describes ProfileSummary gathered data in a module. Example:
```mlir
llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
#llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
max_internal_count = 86427, max_function_count = 4691,
num_counts = 3712, num_functions = 796,
is_partial_profile = 0 : i64,
partial_profile_ratio = 0.000000e+00 : f64,
detailed_summary = [
#llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
#llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
]>>]
```
}];
let parameters = (
ins "StringAttr":$format, "uint64_t":$total_count, "uint64_t":$max_count,
"uint64_t":$max_internal_count, "uint64_t":$max_function_count,
"uint64_t":$num_counts, "uint64_t":$num_functions,
OptionalParameter<"IntegerAttr">:$is_partial_profile,
OptionalParameter<"FloatAttr">:$partial_profile_ratio,
"ArrayAttr":$detailed_summary);

let assemblyFormat = "`<` struct(params) `>`";
}

//===----------------------------------------------------------------------===//
// LLVM_DependentLibrariesAttr
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def LLVM_Dialect : Dialect {
static StringRef getModuleFlagKeyCGProfileName() {
return "CG Profile";
}
static StringRef getModuleFlagKeyProfileSummaryName() {
return "ProfileSummary";
}

/// Returns `true` if the given type is compatible with the LLVM dialect.
static bool isCompatibleType(Type);
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,19 @@ ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) {
if (auto summaryAttr = dyn_cast<ModuleFlagProfileSummaryAttr>(value)) {
StringRef fmt = summaryAttr.getFormat().getValue();
if (fmt != "SampleProfile" && fmt != "InstrProf" && fmt != "CSInstrProf")
return emitError() << "'ProfileFormat' must be 'SampleProfile', "
"'InstrProf' or 'CSInstrProf'";
} else {
return emitError() << "'ProfileSummary' key expects a "
"'#llvm.profile_summary' attribute";
}
return success();
}

if (isa<IntegerAttr, StringAttr>(value))
return success();

Expand Down
68 changes: 68 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,72 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
}
return llvm::MDTuple::getDistinct(context, nodes);
}

return nullptr;
}

static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
StringRef key, ModuleFlagProfileSummaryAttr summaryAttr,
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) {
llvm::LLVMContext &context = builder.getContext();
llvm::MDBuilder mdb(context);
SmallVector<llvm::Metadata *> summaryNodes;

auto getIntTuple = [&](StringRef key, uint64_t val) -> llvm::MDTuple * {
SmallVector<llvm::Metadata *> tupleNodes{
mdb.createString(key), mdb.createConstant(llvm::ConstantInt::get(
llvm::Type::getInt64Ty(context), val))};
return llvm::MDTuple::get(context, tupleNodes);
};

SmallVector<llvm::Metadata *> fmtNode{
mdb.createString("ProfileFormat"),
mdb.createString(summaryAttr.getFormat().getValue())};

SmallVector<llvm::Metadata *> vals = {
llvm::MDTuple::get(context, fmtNode),
getIntTuple("TotalCount", summaryAttr.getTotalCount()),
getIntTuple("MaxCount", summaryAttr.getMaxCount()),
getIntTuple("MaxInternalCount", summaryAttr.getMaxInternalCount()),
getIntTuple("MaxFunctionCount", summaryAttr.getMaxFunctionCount()),
getIntTuple("NumCounts", summaryAttr.getNumCounts()),
getIntTuple("NumFunctions", summaryAttr.getNumFunctions()),
};

if (summaryAttr.getIsPartialProfile())
vals.push_back(getIntTuple("IsPartialProfile",
summaryAttr.getIsPartialProfile().getUInt()));

if (summaryAttr.getPartialProfileRatio()) {
SmallVector<llvm::Metadata *> tupleNodes{
mdb.createString("PartialProfileRatio"),
mdb.createConstant(llvm::ConstantFP::get(
llvm::Type::getDoubleTy(context),
summaryAttr.getPartialProfileRatio().getValue()))};
vals.push_back(llvm::MDTuple::get(context, tupleNodes));
}

SmallVector<llvm::Metadata *> detailedEntries;
for (auto detailedEntry :
summaryAttr.getDetailedSummary()
.getAsRange<ModuleFlagProfileSummaryDetailedAttr>()) {
SmallVector<llvm::Metadata *> tupleNodes{
mdb.createConstant(llvm::ConstantInt::get(
llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())),
mdb.createConstant(llvm::ConstantInt::get(
llvm::Type::getInt64Ty(context), detailedEntry.getMinCount())),
mdb.createConstant(llvm::ConstantInt::get(
llvm::Type::getInt64Ty(context), detailedEntry.getNumCounts()))};
detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes));
}
SmallVector<llvm::Metadata *> detailedSummary{
mdb.createString("DetailedSummary"),
llvm::MDTuple::get(context, detailedEntries)};
vals.push_back(llvm::MDTuple::get(context, detailedSummary));

return llvm::MDNode::get(context, vals);
}

static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
Expand All @@ -323,6 +386,11 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
arrayAttr, builder,
moduleTranslation);
})
.Case<ModuleFlagProfileSummaryAttr>([&](auto summaryAttr) {
return convertModuleFlagProfileSummaryAttr(
flagAttr.getKey().getValue(), summaryAttr, builder,
moduleTranslation);
})
.Default([](auto) { return nullptr; });

assert(valueMetadata && "expected valid metadata");
Expand Down
Loading
Loading