Skip to content

Commit def6b16

Browse files
committed
Use enum kind for format
1 parent b4acc60 commit def6b16

File tree

9 files changed

+48
-37
lines changed

9 files changed

+48
-37
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,7 +1404,7 @@ def ModuleFlagProfileSummaryAttr
14041404
Describes ProfileSummary gathered data in a module. Example:
14051405
```mlir
14061406
llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
1407-
#llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
1407+
#llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
14081408
max_internal_count = 86427, max_function_count = 4691,
14091409
num_counts = 3712, num_functions = 796,
14101410
is_partial_profile = 0,
@@ -1415,13 +1415,13 @@ def ModuleFlagProfileSummaryAttr
14151415
>>]
14161416
```
14171417
}];
1418-
let parameters = (
1419-
ins "StringAttr":$format, "uint64_t":$total_count, "uint64_t":$max_count,
1420-
"uint64_t":$max_internal_count, "uint64_t":$max_function_count,
1421-
"uint64_t":$num_counts, "uint64_t":$num_functions,
1422-
OptionalParameter<"std::optional<uint64_t>">:$is_partial_profile,
1423-
OptionalParameter<"FloatAttr">:$partial_profile_ratio,
1424-
ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary);
1418+
let parameters = (ins "ProfileSummaryFormatKind":$format,
1419+
"uint64_t":$total_count, "uint64_t":$max_count,
1420+
"uint64_t":$max_internal_count, "uint64_t":$max_function_count,
1421+
"uint64_t":$num_counts, "uint64_t":$num_functions,
1422+
OptionalParameter<"std::optional<uint64_t>">:$is_partial_profile,
1423+
OptionalParameter<"FloatAttr">:$partial_profile_ratio,
1424+
ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary);
14251425

14261426
let assemblyFormat = "`<` struct(params) `>`";
14271427
}

mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@ def FPExceptionBehaviorAttr : LLVM_EnumAttr<
823823
}
824824

825825
//===----------------------------------------------------------------------===//
826-
// Module Flag Behavior
826+
// Module Flags
827827
//===----------------------------------------------------------------------===//
828828

829829
// These values must match llvm::Module::ModFlagBehavior ones.
@@ -855,6 +855,21 @@ def ModFlagBehaviorAttr : LLVM_EnumAttr<
855855
let cppNamespace = "::mlir::LLVM";
856856
}
857857

858+
def LLVM_ProfileSummaryFormatSampleProfile : I64EnumAttrCase<"SampleProfile",
859+
0>;
860+
def LLVM_ProfileSummaryFormatInstrProf : I64EnumAttrCase<"InstrProf", 1>;
861+
def LLVM_ProfileSummaryFormatCSInstrProf : I64EnumAttrCase<"CSInstrProf", 2>;
862+
863+
def LLVM_ProfileSummaryFormatKind : I64EnumAttr<
864+
"ProfileSummaryFormatKind",
865+
"LLVM ProfileSummary format kinds", [
866+
LLVM_ProfileSummaryFormatSampleProfile,
867+
LLVM_ProfileSummaryFormatInstrProf,
868+
LLVM_ProfileSummaryFormatCSInstrProf,
869+
]> {
870+
let cppNamespace = "::mlir::LLVM";
871+
}
872+
858873
//===----------------------------------------------------------------------===//
859874
// UWTableKind
860875
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -391,15 +391,9 @@ ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
391391
}
392392

393393
if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) {
394-
if (auto summaryAttr = dyn_cast<ModuleFlagProfileSummaryAttr>(value)) {
395-
StringRef fmt = summaryAttr.getFormat().getValue();
396-
if (fmt != "SampleProfile" && fmt != "InstrProf" && fmt != "CSInstrProf")
397-
return emitError() << "'ProfileFormat' must be 'SampleProfile', "
398-
"'InstrProf' or 'CSInstrProf'";
399-
} else {
394+
if (!isa<ModuleFlagProfileSummaryAttr>(value))
400395
return emitError() << "'ProfileSummary' key expects a "
401396
"'#llvm.profile_summary' attribute";
402-
}
403397
return success();
404398
}
405399

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
320320

321321
SmallVector<llvm::Metadata *> fmtNode{
322322
mdb.createString("ProfileFormat"),
323-
mdb.createString(summaryAttr.getFormat().getValue())};
323+
mdb.createString(
324+
stringifyProfileSummaryFormatKind(summaryAttr.getFormat()))};
324325

325326
SmallVector<llvm::Metadata *> vals = {
326327
llvm::MDTuple::get(context, fmtNode),

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -576,34 +576,32 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
576576
return tupleEntry;
577577
};
578578

579-
auto getFormat = [&](const llvm::MDOperand &formatMD) -> StringAttr {
579+
auto getFormat = [&](const llvm::MDOperand &formatMD)
580+
-> std::optional<ProfileSummaryFormatKind> {
580581
auto *tupleEntry = getMDTuple(formatMD);
581582
if (!tupleEntry)
582-
return nullptr;
583+
return std::nullopt;
583584

584585
llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
585586
if (!keyMD || keyMD->getString() != "ProfileFormat") {
586587
emitWarning(mlirModule.getLoc())
587588
<< "expected 'ProfileFormat' key: "
588589
<< diagMD(tupleEntry->getOperand(0), llvmModule);
589-
return nullptr;
590+
return std::nullopt;
590591
}
591592

592593
llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(1));
593-
auto formatAttr = llvm::StringSwitch<std::string>(valMD->getString())
594-
.Case("SampleProfile", "SampleProfile")
595-
.Case("InstrProf", "InstrProf")
596-
.Case("CSInstrProf", "CSInstrProf")
597-
.Default("");
598-
if (formatAttr.empty()) {
594+
std::optional<ProfileSummaryFormatKind> fmtKind =
595+
symbolizeProfileSummaryFormatKind(valMD->getString());
596+
if (!fmtKind) {
599597
emitWarning(mlirModule.getLoc())
600598
<< "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
601599
"but found: "
602600
<< diagMD(valMD, llvmModule);
603-
return nullptr;
601+
return std::nullopt;
604602
}
605603

606-
return StringAttr::get(mlirModule->getContext(), formatAttr);
604+
return fmtKind;
607605
};
608606

609607
auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey,
@@ -746,8 +744,9 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
746744
// Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
747745
// a fixed order: format, total count, etc.
748746
SmallVector<Attribute> profileSummary;
749-
StringAttr format = getFormat(mdTuple->getOperand(summayIdx++));
750-
if (!format)
747+
std::optional<ProfileSummaryFormatKind> format =
748+
getFormat(mdTuple->getOperand(summayIdx++));
749+
if (!format.has_value())
751750
return nullptr;
752751

753752
uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
@@ -793,7 +792,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
793792

794793
// Build the final profile summary attribute.
795794
return ModuleFlagProfileSummaryAttr::get(
796-
mlirModule->getContext(), format, totalCount, maxCount, maxInternalCount,
795+
mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount,
797796
maxFunctionCount, numCounts, numFunctions, isPartialProfile,
798797
partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary);
799798
}

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1807,16 +1807,18 @@ module {
18071807

18081808
// -----
18091809

1810-
// expected-error@below {{'ProfileFormat' must be 'SampleProfile', 'InstrProf' or 'CSInstrProf'}}
18111810
llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
1811+
// expected-error@below {{expected one of [SampleProfile, InstrProf, CSInstrProf] for LLVM ProfileSummary format kinds, got: YoloFmt}}
18121812
#llvm.profile_summary<format = "YoloFmt", total_count = 263646, max_count = 86427,
1813+
// expected-error@above {{failed to parse ModuleFlagProfileSummaryAttr parameter 'format' which is to be a `ProfileSummaryFormatKind`}}
18131814
max_internal_count = 86427, max_function_count = 4691,
18141815
num_counts = 3712, num_functions = 796,
18151816
is_partial_profile = 0,
18161817
partial_profile_ratio = 0.000000e+00 : f64,
18171818
detailed_summary =
18181819
<cut_off = 10000, min_count = 86427, num_counts = 1>,
18191820
<cut_off = 100000, min_count = 86427, num_counts = 1>
1821+
// expected-error@below {{failed to parse ModuleFlagAttr parameter}}
18201822
>>]
18211823

18221824
// -----

mlir/test/Dialect/LLVMIR/module-roundtrip.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ module {
1313
#llvm.cgprofile_entry<from = @to, to = @from, count = 222>
1414
]>,
1515
#llvm.mlir.module_flag<error, "ProfileSummary",
16-
#llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
16+
#llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
1717
max_internal_count = 86427, max_function_count = 4691,
1818
num_counts = 3712, num_functions = 796,
1919
is_partial_profile = 0,
@@ -37,7 +37,7 @@ module {
3737
// CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
3838
// CHECK-SAME: ]>,
3939
// CHECK-SAME: #llvm.mlir.module_flag<error, "ProfileSummary",
40-
// CHECK-SAME: #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
40+
// CHECK-SAME: #llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
4141
// CHECK-SAME: max_internal_count = 86427, max_function_count = 4691,
4242
// CHECK-SAME: num_counts = 3712, num_functions = 796,
4343
// CHECK-SAME: is_partial_profile = 0,

mlir/test/Target/LLVMIR/Import/module-flags.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ declare void @to()
5858
!31887 = !{i32 100000, i64 86427, i32 1}
5959

6060
; CHECK: llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
61-
; CHECK-SAME: #llvm.profile_summary<format = "InstrProf", total_count = 263646,
61+
; CHECK-SAME: #llvm.profile_summary<format = InstrProf, total_count = 263646,
6262
; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691,
6363
; CHECK-SAME: num_counts = 3712, num_functions = 796, is_partial_profile = 0,
6464
; CHECK-SAME: partial_profile_ratio = 0.000000e+00 : f64,
@@ -88,7 +88,7 @@ declare void @to()
8888
!41887 = !{i32 100000, i64 86427, i32 1}
8989

9090
; CHECK: llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
91-
; CHECK-SAME: #llvm.profile_summary<format = "InstrProf", total_count = 263646,
91+
; CHECK-SAME: #llvm.profile_summary<format = InstrProf, total_count = 263646,
9292
; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691,
9393
; CHECK-SAME: num_counts = 3712, num_functions = 796,
9494
; CHECK-SAME: detailed_summary =

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2884,7 +2884,7 @@ llvm.func @to()
28842884
// -----
28852885

28862886
llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
2887-
#llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
2887+
#llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
28882888
max_internal_count = 86427, max_function_count = 4691,
28892889
num_counts = 3712, num_functions = 796,
28902890
is_partial_profile = 0,

0 commit comments

Comments
 (0)