32
32
#include " type_graph/DrgnParser.h"
33
33
#include " type_graph/EnforceCompatibility.h"
34
34
#include " type_graph/Flattener.h"
35
+ #include " type_graph/KeyCapture.h"
35
36
#include " type_graph/NameGen.h"
36
37
#include " type_graph/Prune.h"
37
38
#include " type_graph/RemoveMembers.h"
@@ -46,13 +47,15 @@ namespace oi::detail {
46
47
using type_graph::AddChildren;
47
48
using type_graph::AddPadding;
48
49
using type_graph::AlignmentCalc;
50
+ using type_graph::CaptureKeys;
49
51
using type_graph::Class;
50
52
using type_graph::Container;
51
53
using type_graph::DrgnParser;
52
54
using type_graph::DrgnParserOptions;
53
55
using type_graph::EnforceCompatibility;
54
56
using type_graph::Enum;
55
57
using type_graph::Flattener;
58
+ using type_graph::KeyCapture;
56
59
using type_graph::Member;
57
60
using type_graph::NameGen;
58
61
using type_graph::Primitive;
@@ -96,12 +99,17 @@ void defineMacros(std::string& code) {
96
99
}
97
100
}
98
101
99
- void defineArray (std::string& code) {
102
+ void defineInternalTypes (std::string& code) {
100
103
code += R"(
101
104
template<typename T, int N>
102
105
struct OIArray {
103
106
T vals[N];
104
107
};
108
+
109
+ // Just here to give a different type name to containers whose keys we'll capture
110
+ template <typename T>
111
+ struct OICaptureKeys : public T {
112
+ };
105
113
)" ;
106
114
}
107
115
@@ -831,6 +839,8 @@ void genContainerTypeHandler(FeatureSet features,
831
839
return ;
832
840
}
833
841
842
+ code += c.codegen .extra ;
843
+
834
844
// TODO: Move this check into the ContainerInfo parsing once always enabled.
835
845
const auto & func = c.codegen .traversalFunc ;
836
846
const auto & processors = c.codegen .processors ;
@@ -859,6 +869,10 @@ void genContainerTypeHandler(FeatureSet features,
859
869
if (!templateParams.empty ())
860
870
containerWithTypes += ' >' ;
861
871
872
+ if (c.captureKeys ) {
873
+ containerWithTypes = " OICaptureKeys<" + containerWithTypes + " >" ;
874
+ }
875
+
862
876
code += " template <typename DB" ;
863
877
types = 0 , values = 0 ;
864
878
for (const auto & p : templateParams) {
@@ -875,6 +889,11 @@ void genContainerTypeHandler(FeatureSet features,
875
889
code += containerWithTypes;
876
890
code += " > {\n " ;
877
891
892
+ if (c.captureKeys ) {
893
+ code += " static constexpr bool captureKeys = true;\n " ;
894
+ } else {
895
+ code += " static constexpr bool captureKeys = false;\n " ;
896
+ }
878
897
code += " using type = " ;
879
898
if (processors.empty ()) {
880
899
code += " types::st::Unit<DB>" ;
@@ -931,9 +950,80 @@ void genContainerTypeHandler(FeatureSet features,
931
950
code += " };\n\n " ;
932
951
}
933
952
953
+ void addCaptureKeySupport (std::string& code) {
954
+ code += R"(
955
+ template <typename DB, typename T>
956
+ class CaptureKeyHandler {
957
+ public:
958
+ using type = types::st::Sum<DB, types::st::VarInt<DB>, types::st::VarInt<DB>>;
959
+
960
+ static auto captureKey(const T& key, auto returnArg) {
961
+ // Save scalars keys directly, otherwise save pointers for complex types
962
+ if constexpr (std::is_scalar_v<T>) {
963
+ return returnArg.template write<0>().write(static_cast<uint64_t>(key));
964
+ }
965
+ return returnArg.template write<1>().write(reinterpret_cast<uintptr_t>(&key));
966
+ }
967
+ };
968
+
969
+ template <bool CaptureKeys, typename DB, typename T>
970
+ auto maybeCaptureKey(const T& key, auto returnArg) {
971
+ if constexpr (CaptureKeys) {
972
+ return returnArg.delegate([&key](auto ret) {
973
+ return CaptureKeyHandler<DB, T>::captureKey(key, ret);
974
+ });
975
+ } else {
976
+ return returnArg;
977
+ }
978
+ }
979
+
980
+ template <typename DB, typename T>
981
+ static constexpr inst::ProcessorInst CaptureKeysProcessor{
982
+ CaptureKeyHandler<DB, T>::type::describe,
983
+ [](result::Element& el, std::function<void(inst::Inst)> stack_ins, ParsedData d) {
984
+ if constexpr (std::is_same_v<
985
+ typename CaptureKeyHandler<DB, T>::type,
986
+ types::st::List<DB, types::st::VarInt<DB>>>) {
987
+ // String
988
+ auto& str = el.data.emplace<std::string>();
989
+ auto list = std::get<ParsedData::List>(d.val);
990
+ size_t strlen = list.length;
991
+ for (size_t i = 0; i < strlen; i++) {
992
+ auto value = list.values().val;
993
+ auto c = std::get<ParsedData::VarInt>(value).value;
994
+ str.push_back(c);
995
+ }
996
+ } else {
997
+ auto sum = std::get<ParsedData::Sum>(d.val);
998
+ if (sum.index == 0) {
999
+ el.data = oi::result::Element::Scalar{std::get<ParsedData::VarInt>(sum.value().val).value};
1000
+ } else {
1001
+ el.data = oi::result::Element::Pointer{std::get<ParsedData::VarInt>(sum.value().val).value};
1002
+ }
1003
+ }
1004
+ }
1005
+ };
1006
+
1007
+ template <bool CaptureKeys, typename DB, typename T>
1008
+ static constexpr auto maybeCaptureKeysProcessor() {
1009
+ if constexpr (CaptureKeys) {
1010
+ return std::array<inst::ProcessorInst, 1>{
1011
+ CaptureKeysProcessor<DB, T>,
1012
+ };
1013
+ }
1014
+ else {
1015
+ return std::array<inst::ProcessorInst, 0>{};
1016
+ }
1017
+ }
1018
+ )" ;
1019
+ }
1020
+
934
1021
void addStandardTypeHandlers (TypeGraph& typeGraph,
935
1022
FeatureSet features,
936
1023
std::string& code) {
1024
+ if (features[Feature::TreeBuilderV2])
1025
+ addCaptureKeySupport (code);
1026
+
937
1027
// Provide a wrapper function, getSizeType, to infer T instead of having to
938
1028
// explicitly specify it with TypeHandler<DB, T>::getSizeType every time.
939
1029
code += R"(
@@ -983,6 +1073,10 @@ void CodeGen::addTypeHandlers(const TypeGraph& typeGraph, std::string& code) {
983
1073
} else if (const auto * con = dynamic_cast <const Container*>(&t)) {
984
1074
genContainerTypeHandler (config_.features , definedContainers_,
985
1075
con->containerInfo_ , con->templateParams , code);
1076
+ } else if (const auto * cap = dynamic_cast <const CaptureKeys*>(&t)) {
1077
+ genContainerTypeHandler (config_.features , definedContainers_,
1078
+ cap->containerInfo (),
1079
+ cap->container ().templateParams , code);
986
1080
}
987
1081
}
988
1082
}
@@ -1061,10 +1155,13 @@ void CodeGen::transform(TypeGraph& typeGraph) {
1061
1155
// Calculate alignment before removing members, as those members may have an
1062
1156
// influence on the class' overall alignment.
1063
1157
pm.addPass (AlignmentCalc::createPass ());
1158
+
1064
1159
pm.addPass (RemoveMembers::createPass (config_.membersToStub ));
1065
- if (!config_.features [Feature::TreeBuilderV2]) {
1160
+ if (!config_.features [Feature::TreeBuilderV2])
1066
1161
pm.addPass (EnforceCompatibility::createPass ());
1067
- }
1162
+ if (config_.features [Feature::TreeBuilderV2] &&
1163
+ !config_.keysToCapture .empty ())
1164
+ pm.addPass (KeyCapture::createPass (config_.keysToCapture , containerInfos_));
1068
1165
1069
1166
// Add padding to fill in the gaps of removed members and ensure their
1070
1167
// alignments
@@ -1094,7 +1191,7 @@ void CodeGen::generate(
1094
1191
defineMacros (code);
1095
1192
}
1096
1193
addIncludes (typeGraph, config_.features , code);
1097
- defineArray (code);
1194
+ defineInternalTypes (code);
1098
1195
FuncGen::DefineJitLog (code, config_.features );
1099
1196
1100
1197
if (config_.features [Feature::TypedDataSegment]) {
0 commit comments