Skip to content

Commit 3446339

Browse files
committed
TypeGraph: Add KeyCapture pass
1 parent 0ae08ad commit 3446339

File tree

6 files changed

+321
-0
lines changed

6 files changed

+321
-0
lines changed

oi/type_graph/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_library(type_graph
55
DrgnParser.cpp
66
EnforceCompatibility.cpp
77
Flattener.cpp
8+
KeyCapture.cpp
89
NameGen.cpp
910
PassManager.cpp
1011
Printer.cpp

oi/type_graph/KeyCapture.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "KeyCapture.h"
17+
18+
#include "TypeGraph.h"
19+
20+
namespace oi::detail::type_graph {
21+
22+
Pass KeyCapture::createPass(
23+
const std::vector<OICodeGen::Config::KeyToCapture>& keysToCapture,
24+
std::vector<std::unique_ptr<ContainerInfo>>& containerInfos) {
25+
auto fn = [&keysToCapture, &containerInfos](TypeGraph& typeGraph,
26+
NodeTracker& tracker) {
27+
KeyCapture pass{tracker, typeGraph, keysToCapture, containerInfos};
28+
pass.insertCaptureDataNodes(typeGraph.rootTypes());
29+
};
30+
31+
return Pass("KeyCapture", fn);
32+
}
33+
34+
/*
35+
* This function should be used as the main entry point to this pass, to add
36+
* special handling of top-level types.
37+
*/
38+
void KeyCapture::insertCaptureDataNodes(
39+
std::vector<std::reference_wrapper<Type>>& types) {
40+
for (const auto& keyToCapture : keysToCapture_) {
41+
if (!keyToCapture.topLevel)
42+
continue;
43+
44+
// Capture keys from all top-level types
45+
for (size_t i = 0; i < types.size(); i++) {
46+
types[i] = captureKey(types[i]);
47+
}
48+
break;
49+
}
50+
51+
for (const auto& type : types) {
52+
accept(type);
53+
}
54+
}
55+
56+
void KeyCapture::accept(Type& type) {
57+
if (tracker_.visit(type))
58+
return;
59+
60+
type.accept(*this);
61+
}
62+
63+
void KeyCapture::visit(Class& c) {
64+
for (const auto& keyToCapture : keysToCapture_) {
65+
if (!keyToCapture.type.has_value() || c.name() != *keyToCapture.type)
66+
continue;
67+
if (!keyToCapture.member.has_value())
68+
continue;
69+
for (size_t i = 0; i < c.members.size(); i++) {
70+
auto& member = c.members[i];
71+
if (member.name != *keyToCapture.member)
72+
continue;
73+
74+
member = Member{captureKey(member.type()), member};
75+
}
76+
}
77+
78+
RecursiveVisitor::visit(c);
79+
}
80+
81+
/*
82+
* captureKey
83+
*
84+
* If the given type is a container, insert a CaptureKey node above it.
85+
* Otherwise, just return the container node unchanged.
86+
*
87+
* Before:
88+
* Container: std::map
89+
* Param
90+
* [KEY]
91+
* Param
92+
* [VAL]
93+
*
94+
* After:
95+
* CaptureKeys
96+
* Container: std::map
97+
* Param
98+
* [KEY]
99+
* Param
100+
* [VAL]
101+
*/
102+
Type& KeyCapture::captureKey(Type& type) {
103+
auto* container = dynamic_cast<Container*>(&type);
104+
if (!container) // We only want to capture keys from containers
105+
return type;
106+
107+
/*
108+
* Create a copy of the container info for capturing keys.
109+
* CodeGen and other places may deduplicate containers based on the container
110+
* info object, so it is necessary to create a new one when we want different
111+
* behaviour.
112+
*/
113+
auto newContainerInfo = container->containerInfo_.clone();
114+
newContainerInfo.captureKeys = true;
115+
auto infoPtr = std::make_unique<ContainerInfo>(std::move(newContainerInfo));
116+
const auto& info = containerInfos_.emplace_back(std::move(infoPtr));
117+
118+
auto& captureKeysNode = typeGraph_.makeType<CaptureKeys>(*container, *info);
119+
return captureKeysNode;
120+
}
121+
122+
} // namespace oi::detail::type_graph

oi/type_graph/KeyCapture.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <functional>
19+
#include <memory>
20+
#include <vector>
21+
22+
#include "NodeTracker.h"
23+
#include "PassManager.h"
24+
#include "Types.h"
25+
#include "Visitor.h"
26+
#include "oi/OICodeGen.h"
27+
28+
namespace oi::detail::type_graph {
29+
30+
/*
31+
* KeyCapture
32+
*
33+
* Marks containers for which the user has requested key-capture.
34+
*/
35+
class KeyCapture : public RecursiveVisitor {
36+
public:
37+
static Pass createPass(
38+
const std::vector<OICodeGen::Config::KeyToCapture>& keysToCapture,
39+
std::vector<std::unique_ptr<ContainerInfo>>& containerInfos);
40+
41+
KeyCapture(NodeTracker& tracker,
42+
TypeGraph& typeGraph,
43+
const std::vector<OICodeGen::Config::KeyToCapture>& keysToCapture,
44+
std::vector<std::unique_ptr<ContainerInfo>>& containerInfos)
45+
: tracker_(tracker),
46+
typeGraph_(typeGraph),
47+
keysToCapture_(keysToCapture),
48+
containerInfos_(containerInfos) {
49+
}
50+
51+
using RecursiveVisitor::accept;
52+
using RecursiveVisitor::visit;
53+
54+
void insertCaptureDataNodes(std::vector<std::reference_wrapper<Type>>& types);
55+
void visit(Class& c) override;
56+
57+
private:
58+
NodeTracker& tracker_;
59+
TypeGraph& typeGraph_;
60+
const std::vector<OICodeGen::Config::KeyToCapture>& keysToCapture_;
61+
std::vector<std::unique_ptr<ContainerInfo>>& containerInfos_;
62+
63+
void accept(Type& type) override;
64+
Type& captureKey(Type& type);
65+
};
66+
67+
} // namespace oi::detail::type_graph

oi/type_graph/Types.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,14 @@ class Member {
109109
bitsize(bitsize) {
110110
}
111111

112+
Member(Type& type, const Member& other)
113+
: type_(type),
114+
name(other.name),
115+
inputName(other.inputName),
116+
bitOffset(other.bitOffset),
117+
bitsize(other.bitsize) {
118+
}
119+
112120
Type& type() const {
113121
return type_;
114122
}
@@ -299,6 +307,17 @@ class Container : public Type {
299307
id_(id) {
300308
}
301309

310+
Container(NodeId id,
311+
const Container& other,
312+
const ContainerInfo& containerInfo)
313+
: templateParams(other.templateParams),
314+
containerInfo_(containerInfo),
315+
name_(other.name_),
316+
inputName_(other.inputName_),
317+
size_(other.size_),
318+
id_(id) {
319+
}
320+
302321
static inline constexpr bool has_node_id = true;
303322

304323
DECLARE_ACCEPT

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ add_executable(test_type_graph
4343
test_drgn_parser.cpp
4444
test_enforce_compatibility.cpp
4545
test_flattener.cpp
46+
test_key_capture.cpp
4647
test_name_gen.cpp
4748
test_node_tracker.cpp
4849
test_prune.cpp

test/test_key_capture.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#include <gtest/gtest.h>
2+
3+
#include "oi/type_graph/KeyCapture.h"
4+
#include "oi/type_graph/Types.h"
5+
#include "test/type_graph_utils.h"
6+
7+
using namespace type_graph;
8+
9+
TEST(KeyCaptureTest, InClass) {
10+
std::vector<OICodeGen::Config::KeyToCapture> keysToCapture = {
11+
{"MyClass", "b"},
12+
};
13+
std::vector<std::unique_ptr<ContainerInfo>> containerInfos;
14+
test(KeyCapture::createPass(keysToCapture, containerInfos), R"(
15+
[0] Class: MyClass (size: 12)
16+
Member: a (offset: 0)
17+
Primitive: int32_t
18+
Member: b (offset: 8)
19+
[1] Container: std::map (size: 24)
20+
Param
21+
Primitive: int32_t
22+
Param
23+
Primitive: int32_t
24+
Member: c (offset: 8)
25+
[1]
26+
)",
27+
R"(
28+
[0] Class: MyClass (size: 12)
29+
Member: a (offset: 0)
30+
Primitive: int32_t
31+
Member: b (offset: 8)
32+
CaptureKeys
33+
[1] Container: std::map (size: 24)
34+
Param
35+
Primitive: int32_t
36+
Param
37+
Primitive: int32_t
38+
Member: c (offset: 8)
39+
[1]
40+
)");
41+
}
42+
43+
TEST(KeyCaptureTest, MapInMap) {
44+
std::vector<OICodeGen::Config::KeyToCapture> keysToCapture = {
45+
{"MyClass", "a"},
46+
};
47+
std::vector<std::unique_ptr<ContainerInfo>> containerInfos;
48+
test(KeyCapture::createPass(keysToCapture, containerInfos), R"(
49+
[0] Class: MyClass (size: 12)
50+
Member: a (offset: 8)
51+
[1] Container: std::map (size: 24)
52+
Param
53+
Primitive: int32_t
54+
Param
55+
[2] Container: std::map (size: 24)
56+
Param
57+
Primitive: int32_t
58+
Param
59+
Primitive: int32_t
60+
)",
61+
R"(
62+
[0] Class: MyClass (size: 12)
63+
Member: a (offset: 8)
64+
CaptureKeys
65+
[1] Container: std::map (size: 24)
66+
Param
67+
Primitive: int32_t
68+
Param
69+
[2] Container: std::map (size: 24)
70+
Param
71+
Primitive: int32_t
72+
Param
73+
Primitive: int32_t
74+
)");
75+
}
76+
77+
TEST(KeyCaptureTest, TopLevel) {
78+
std::vector<OICodeGen::Config::KeyToCapture> keysToCapture = {
79+
{{}, {}, true},
80+
};
81+
std::vector<std::unique_ptr<ContainerInfo>> containerInfos;
82+
test(KeyCapture::createPass(keysToCapture, containerInfos), R"(
83+
[0] Container: std::map (size: 24)
84+
Param
85+
Primitive: int32_t
86+
Param
87+
Primitive: int32_t
88+
)",
89+
R"(
90+
CaptureKeys
91+
[0] Container: std::map (size: 24)
92+
Param
93+
Primitive: int32_t
94+
Param
95+
Primitive: int32_t
96+
)");
97+
}
98+
99+
TEST(KeyCaptureTest, TopLevelNotCaptured) {
100+
std::vector<OICodeGen::Config::KeyToCapture> keysToCapture = {
101+
{"MyClass", "a"},
102+
};
103+
std::vector<std::unique_ptr<ContainerInfo>> containerInfos;
104+
testNoChange(KeyCapture::createPass(keysToCapture, containerInfos), R"(
105+
[0] Container: std::map (size: 24)
106+
Param
107+
Primitive: int32_t
108+
Param
109+
Primitive: int32_t
110+
)");
111+
}

0 commit comments

Comments
 (0)