Skip to content

Commit 4afa2ff

Browse files
committed
CodeGen: Add support for key capture
1 parent 3446339 commit 4afa2ff

File tree

2 files changed

+118
-4
lines changed

2 files changed

+118
-4
lines changed

oi/CodeGen.cpp

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "type_graph/DrgnParser.h"
3333
#include "type_graph/EnforceCompatibility.h"
3434
#include "type_graph/Flattener.h"
35+
#include "type_graph/KeyCapture.h"
3536
#include "type_graph/NameGen.h"
3637
#include "type_graph/Prune.h"
3738
#include "type_graph/RemoveMembers.h"
@@ -46,13 +47,15 @@ namespace oi::detail {
4647
using type_graph::AddChildren;
4748
using type_graph::AddPadding;
4849
using type_graph::AlignmentCalc;
50+
using type_graph::CaptureKeys;
4951
using type_graph::Class;
5052
using type_graph::Container;
5153
using type_graph::DrgnParser;
5254
using type_graph::DrgnParserOptions;
5355
using type_graph::EnforceCompatibility;
5456
using type_graph::Enum;
5557
using type_graph::Flattener;
58+
using type_graph::KeyCapture;
5659
using type_graph::Member;
5760
using type_graph::NameGen;
5861
using type_graph::Primitive;
@@ -96,12 +99,17 @@ void defineMacros(std::string& code) {
9699
}
97100
}
98101

99-
void defineArray(std::string& code) {
102+
void defineInternalTypes(std::string& code) {
100103
code += R"(
101104
template<typename T, int N>
102105
struct OIArray {
103106
T vals[N];
104107
};
108+
109+
// Just here to give a different type name to containers whose keys we'll capture
110+
template <typename T>
111+
struct OICaptureKeys : public T {
112+
};
105113
)";
106114
}
107115

@@ -831,6 +839,8 @@ void genContainerTypeHandler(FeatureSet features,
831839
return;
832840
}
833841

842+
code += c.codegen.extra;
843+
834844
// TODO: Move this check into the ContainerInfo parsing once always enabled.
835845
const auto& func = c.codegen.traversalFunc;
836846
const auto& processors = c.codegen.processors;
@@ -859,6 +869,10 @@ void genContainerTypeHandler(FeatureSet features,
859869
if (!templateParams.empty())
860870
containerWithTypes += '>';
861871

872+
if (c.captureKeys) {
873+
containerWithTypes = "OICaptureKeys<" + containerWithTypes + ">";
874+
}
875+
862876
code += "template <typename DB";
863877
types = 0, values = 0;
864878
for (const auto& p : templateParams) {
@@ -875,6 +889,11 @@ void genContainerTypeHandler(FeatureSet features,
875889
code += containerWithTypes;
876890
code += "> {\n";
877891

892+
if (c.captureKeys) {
893+
code += " static constexpr bool captureKeys = true;\n";
894+
} else {
895+
code += " static constexpr bool captureKeys = false;\n";
896+
}
878897
code += " using type = ";
879898
if (processors.empty()) {
880899
code += "types::st::Unit<DB>";
@@ -931,9 +950,80 @@ void genContainerTypeHandler(FeatureSet features,
931950
code += "};\n\n";
932951
}
933952

953+
void addCaptureKeySupport(std::string& code) {
954+
code += R"(
955+
template <typename DB, typename T>
956+
class CaptureKeyHandler {
957+
public:
958+
using type = types::st::Sum<DB, types::st::VarInt<DB>, types::st::VarInt<DB>>;
959+
960+
static auto captureKey(const T& key, auto returnArg) {
961+
// Save scalars keys directly, otherwise save pointers for complex types
962+
if constexpr (std::is_scalar_v<T>) {
963+
return returnArg.template write<0>().write(static_cast<uint64_t>(key));
964+
}
965+
return returnArg.template write<1>().write(reinterpret_cast<uintptr_t>(&key));
966+
}
967+
};
968+
969+
template <bool CaptureKeys, typename DB, typename T>
970+
auto maybeCaptureKey(const T& key, auto returnArg) {
971+
if constexpr (CaptureKeys) {
972+
return returnArg.delegate([&key](auto ret) {
973+
return CaptureKeyHandler<DB, T>::captureKey(key, ret);
974+
});
975+
} else {
976+
return returnArg;
977+
}
978+
}
979+
980+
template <typename DB, typename T>
981+
static constexpr inst::ProcessorInst CaptureKeysProcessor{
982+
CaptureKeyHandler<DB, T>::type::describe,
983+
[](result::Element& el, std::function<void(inst::Inst)> stack_ins, ParsedData d) {
984+
if constexpr (std::is_same_v<
985+
typename CaptureKeyHandler<DB, T>::type,
986+
types::st::List<DB, types::st::VarInt<DB>>>) {
987+
// String
988+
auto& str = el.data.emplace<std::string>();
989+
auto list = std::get<ParsedData::List>(d.val);
990+
size_t strlen = list.length;
991+
for (size_t i = 0; i < strlen; i++) {
992+
auto value = list.values().val;
993+
auto c = std::get<ParsedData::VarInt>(value).value;
994+
str.push_back(c);
995+
}
996+
} else {
997+
auto sum = std::get<ParsedData::Sum>(d.val);
998+
if (sum.index == 0) {
999+
el.data = oi::result::Element::Scalar{std::get<ParsedData::VarInt>(sum.value().val).value};
1000+
} else {
1001+
el.data = oi::result::Element::Pointer{std::get<ParsedData::VarInt>(sum.value().val).value};
1002+
}
1003+
}
1004+
}
1005+
};
1006+
1007+
template <bool CaptureKeys, typename DB, typename T>
1008+
static constexpr auto maybeCaptureKeysProcessor() {
1009+
if constexpr (CaptureKeys) {
1010+
return std::array<inst::ProcessorInst, 1>{
1011+
CaptureKeysProcessor<DB, T>,
1012+
};
1013+
}
1014+
else {
1015+
return std::array<inst::ProcessorInst, 0>{};
1016+
}
1017+
}
1018+
)";
1019+
}
1020+
9341021
void addStandardTypeHandlers(TypeGraph& typeGraph,
9351022
FeatureSet features,
9361023
std::string& code) {
1024+
if (features[Feature::TreeBuilderV2])
1025+
addCaptureKeySupport(code);
1026+
9371027
// Provide a wrapper function, getSizeType, to infer T instead of having to
9381028
// explicitly specify it with TypeHandler<DB, T>::getSizeType every time.
9391029
code += R"(
@@ -983,6 +1073,10 @@ void CodeGen::addTypeHandlers(const TypeGraph& typeGraph, std::string& code) {
9831073
} else if (const auto* con = dynamic_cast<const Container*>(&t)) {
9841074
genContainerTypeHandler(config_.features, definedContainers_,
9851075
con->containerInfo_, con->templateParams, code);
1076+
} else if (const auto* cap = dynamic_cast<const CaptureKeys*>(&t)) {
1077+
genContainerTypeHandler(config_.features, definedContainers_,
1078+
cap->containerInfo(),
1079+
cap->container().templateParams, code);
9861080
}
9871081
}
9881082
}
@@ -1061,10 +1155,13 @@ void CodeGen::transform(TypeGraph& typeGraph) {
10611155
// Calculate alignment before removing members, as those members may have an
10621156
// influence on the class' overall alignment.
10631157
pm.addPass(AlignmentCalc::createPass());
1158+
10641159
pm.addPass(RemoveMembers::createPass(config_.membersToStub));
1065-
if (!config_.features[Feature::TreeBuilderV2]) {
1160+
if (!config_.features[Feature::TreeBuilderV2])
10661161
pm.addPass(EnforceCompatibility::createPass());
1067-
}
1162+
if (config_.features[Feature::TreeBuilderV2] &&
1163+
!config_.keysToCapture.empty())
1164+
pm.addPass(KeyCapture::createPass(config_.keysToCapture, containerInfos_));
10681165

10691166
// Add padding to fill in the gaps of removed members and ensure their
10701167
// alignments
@@ -1094,7 +1191,7 @@ void CodeGen::generate(
10941191
defineMacros(code);
10951192
}
10961193
addIncludes(typeGraph, config_.features, code);
1097-
defineArray(code);
1194+
defineInternalTypes(code);
10981195
FuncGen::DefineJitLog(code, config_.features);
10991196

11001197
if (config_.features[Feature::TypedDataSegment]) {

types/cxx11_string_type.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,23 @@ struct TypeHandler<DB, %1% <T0>> {
5454
};
5555
"""
5656

57+
extra = """
58+
template <typename DB, typename CharT, typename Traits, typename Allocator>
59+
class CaptureKeyHandler<DB, std::__cxx11::basic_string<CharT, Traits, Allocator>> {
60+
public:
61+
// List of characters
62+
using type = types::st::List<DB, types::st::VarInt<DB>>;
63+
64+
static auto captureKey(const std::__cxx11::basic_string<CharT, Traits, Allocator>& key, auto returnArg) {
65+
auto tail = returnArg.write(key.size());
66+
for (auto c : key) {
67+
tail = returnArg.write((uintptr_t)c);
68+
}
69+
return tail.finish();
70+
}
71+
};
72+
"""
73+
5774
traversal_func = """
5875
bool sso = ((uintptr_t)container.data() <
5976
(uintptr_t)(&container + sizeof(std::__cxx11::basic_string<T0>))) &&

0 commit comments

Comments
 (0)