Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,54 @@ def ModuleFlagCGProfileEntryAttr
let assemblyFormat = "`<` struct(params) `>`";
}

def ModuleFlagProfileSummaryDetailedAttr
: LLVM_Attr<"ModuleFlagProfileSummaryDetailed", "profile_summary_detailed"> {
let summary = "ProfileSummary detailed information";
let description = [{
Contains detailed information pertinent to "ProfileSummary" attribute.
A `#llvm.profile_summary` may contain several of it.
```mlir
llvm.module_flags [ ...
detailed_summary =
<cut_off = 10000, min_count = 86427, num_counts = 1>,
<cut_off = 100000, min_count = 86427, num_counts = 1>
```
}];
let parameters = (ins "uint32_t":$cut_off,
"uint64_t":$min_count,
"uint32_t":$num_counts);
let assemblyFormat = "`<` struct(params) `>`";
}

def ModuleFlagProfileSummaryAttr
: LLVM_Attr<"ModuleFlagProfileSummary", "profile_summary"> {
let summary = "ProfileSummary module flag";
let description = [{
Describes ProfileSummary gathered data in a module. Example:
```mlir
llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
#llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
max_internal_count = 86427, max_function_count = 4691,
num_counts = 3712, num_functions = 796,
is_partial_profile = 0,
partial_profile_ratio = 0.000000e+00 : f64,
detailed_summary =
<cut_off = 10000, min_count = 86427, num_counts = 1>,
<cut_off = 100000, min_count = 86427, num_counts = 1>
>>]
```
}];
let parameters = (ins "ProfileSummaryFormatKind":$format,
"uint64_t":$total_count, "uint64_t":$max_count,
"uint64_t":$max_internal_count, "uint64_t":$max_function_count,
"uint64_t":$num_counts, "uint64_t":$num_functions,
OptionalParameter<"std::optional<uint64_t>">:$is_partial_profile,
OptionalParameter<"FloatAttr">:$partial_profile_ratio,
ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary);

let assemblyFormat = "`<` struct(params) `>`";
}

//===----------------------------------------------------------------------===//
// LLVM_DependentLibrariesAttr
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def LLVM_Dialect : Dialect {
static StringRef getModuleFlagKeyCGProfileName() {
return "CG Profile";
}
static StringRef getModuleFlagKeyProfileSummaryName() {
return "ProfileSummary";
}

/// Returns `true` if the given type is compatible with the LLVM dialect.
static bool isCompatibleType(Type);
Expand Down
17 changes: 16 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ def FPExceptionBehaviorAttr : LLVM_EnumAttr<
}

//===----------------------------------------------------------------------===//
// Module Flag Behavior
// Module Flags
//===----------------------------------------------------------------------===//

// These values must match llvm::Module::ModFlagBehavior ones.
Expand Down Expand Up @@ -855,6 +855,21 @@ def ModFlagBehaviorAttr : LLVM_EnumAttr<
let cppNamespace = "::mlir::LLVM";
}

def LLVM_ProfileSummaryFormatSampleProfile : I64EnumAttrCase<"SampleProfile",
0>;
def LLVM_ProfileSummaryFormatInstrProf : I64EnumAttrCase<"InstrProf", 1>;
def LLVM_ProfileSummaryFormatCSInstrProf : I64EnumAttrCase<"CSInstrProf", 2>;

def LLVM_ProfileSummaryFormatKind : I64EnumAttr<
"ProfileSummaryFormatKind",
"LLVM ProfileSummary format kinds", [
LLVM_ProfileSummaryFormatSampleProfile,
LLVM_ProfileSummaryFormatInstrProf,
LLVM_ProfileSummaryFormatCSInstrProf,
]> {
let cppNamespace = "::mlir::LLVM";
}

//===----------------------------------------------------------------------===//
// UWTableKind
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,13 @@ ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) {
if (!isa<ModuleFlagProfileSummaryAttr>(value))
return emitError() << "'ProfileSummary' key expects a "
"'#llvm.profile_summary' attribute";
return success();
}

if (isa<IntegerAttr, StringAttr>(value))
return success();

Expand Down
67 changes: 67 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,71 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
}
return llvm::MDTuple::getDistinct(context, nodes);
}

return nullptr;
}

static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
StringRef key, ModuleFlagProfileSummaryAttr summaryAttr,
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) {
llvm::LLVMContext &context = builder.getContext();
llvm::MDBuilder mdb(context);
SmallVector<llvm::Metadata *> summaryNodes;

auto getIntTuple = [&](StringRef key, uint64_t val) -> llvm::MDTuple * {
SmallVector<llvm::Metadata *> tupleNodes{
mdb.createString(key), mdb.createConstant(llvm::ConstantInt::get(
llvm::Type::getInt64Ty(context), val))};
return llvm::MDTuple::get(context, tupleNodes);
};

SmallVector<llvm::Metadata *> fmtNode{
mdb.createString("ProfileFormat"),
mdb.createString(
stringifyProfileSummaryFormatKind(summaryAttr.getFormat()))};

SmallVector<llvm::Metadata *> vals = {
llvm::MDTuple::get(context, fmtNode),
getIntTuple("TotalCount", summaryAttr.getTotalCount()),
getIntTuple("MaxCount", summaryAttr.getMaxCount()),
getIntTuple("MaxInternalCount", summaryAttr.getMaxInternalCount()),
getIntTuple("MaxFunctionCount", summaryAttr.getMaxFunctionCount()),
getIntTuple("NumCounts", summaryAttr.getNumCounts()),
getIntTuple("NumFunctions", summaryAttr.getNumFunctions()),
};

if (summaryAttr.getIsPartialProfile())
vals.push_back(
getIntTuple("IsPartialProfile", *summaryAttr.getIsPartialProfile()));

if (summaryAttr.getPartialProfileRatio()) {
SmallVector<llvm::Metadata *> tupleNodes{
mdb.createString("PartialProfileRatio"),
mdb.createConstant(llvm::ConstantFP::get(
llvm::Type::getDoubleTy(context),
summaryAttr.getPartialProfileRatio().getValue()))};
vals.push_back(llvm::MDTuple::get(context, tupleNodes));
}

SmallVector<llvm::Metadata *> detailedEntries;
for (auto detailedEntry : summaryAttr.getDetailedSummary()) {
SmallVector<llvm::Metadata *> tupleNodes{
mdb.createConstant(llvm::ConstantInt::get(
llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())),
mdb.createConstant(llvm::ConstantInt::get(
llvm::Type::getInt64Ty(context), detailedEntry.getMinCount())),
mdb.createConstant(llvm::ConstantInt::get(
llvm::Type::getInt64Ty(context), detailedEntry.getNumCounts()))};
detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes));
}
SmallVector<llvm::Metadata *> detailedSummary{
mdb.createString("DetailedSummary"),
llvm::MDTuple::get(context, detailedEntries)};
vals.push_back(llvm::MDTuple::get(context, detailedSummary));

return llvm::MDNode::get(context, vals);
}

static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
Expand All @@ -323,6 +385,11 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
arrayAttr, builder,
moduleTranslation);
})
.Case<ModuleFlagProfileSummaryAttr>([&](auto summaryAttr) {
return convertModuleFlagProfileSummaryAttr(
flagAttr.getKey().getValue(), summaryAttr, builder,
moduleTranslation);
})
.Default([](auto) { return nullptr; });

assert(valueMetadata && "expected valid metadata");
Expand Down
Loading