Commit 725c327

dolpm authored and pytorchmergebot committed
[nativert] add memory overlap debug assertion (pytorch#157290)

Summary: Better safe than sorry: this will throw if memory overlap is detected when using planned tensors while debug mode is enabled, which will make our planning unit tests more robust.

Test Plan: CI

Rollback Plan:

Differential Revision: D77327841

Pull Request resolved: pytorch#157290

Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17
1 parent f87d117 commit 725c327

19 files changed, +604 -76 lines changed
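The assertion itself is not shown on this page, but its spirit is simple: among the tensors the layout planner has packed into a shared buffer, any two whose lifetimes overlap must occupy disjoint byte ranges. The following is only a rough standalone sketch of such a check, with all names (PlannedBuffer, assertNoOverlap) invented for illustration, not taken from this PR:

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical record of one planned allocation: the byte range
// [offset, offset + size) within the shared slab, plus the interval of
// node indices during which the tensor is live.
struct PlannedBuffer {
  size_t offset;
  size_t size;
  size_t lifetime_start;
  size_t lifetime_end;
};

// Sketch: any two buffers that are live at the same time must occupy
// disjoint byte ranges. O(n^2), which is acceptable for a debug-only check.
inline void assertNoOverlap(const std::vector<PlannedBuffer>& buffers) {
  for (size_t i = 0; i < buffers.size(); ++i) {
    for (size_t j = i + 1; j < buffers.size(); ++j) {
      const auto& a = buffers[i];
      const auto& b = buffers[j];
      const bool lifetimes_overlap = a.lifetime_start <= b.lifetime_end &&
          b.lifetime_start <= a.lifetime_end;
      const bool bytes_overlap =
          a.offset < b.offset + b.size && b.offset < a.offset + a.size;
      assert(!(lifetimes_overlap && bytes_overlap) &&
             "planned tensors with overlapping lifetimes must not share bytes");
    }
  }
}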

test/cpp/nativert/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,15 @@ set(NATIVERT_TEST_SRCS
   ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp
   ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp
   ${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp
+  ${TORCH_ROOT}/torch/nativert/executor/Executor.cpp
+  ${TORCH_ROOT}/torch/nativert/kernels/KernelFactory.cpp
+  ${TORCH_ROOT}/torch/nativert/executor/ConstantFolder.cpp
+  ${TORCH_ROOT}/torch/nativert/executor/GraphExecutorBase.cpp
+  ${TORCH_ROOT}/torch/nativert/executor/SerialGraphExecutor.cpp
+  ${TORCH_ROOT}/torch/nativert/executor/ParallelGraphExecutor.cpp
+  ${TORCH_ROOT}/torch/nativert/kernels/AutoFunctionalizeKernel.cpp
+  ${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp
+  ${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp
 )
 
 add_executable(test_nativert
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
+#include <gtest/gtest.h>
+
+#include <fmt/format.h>
+
+#include <torch/nativert/executor/memory/AliasAnalyzer.h>
+#include <torch/nativert/graph/Graph.h>
+
+#include <torch/nativert/executor/Executor.h>
+#include <torch/nativert/kernels/KernelFactory.h>
+
+using namespace ::testing;
+using namespace torch::nativert;
+
+using AliasTestCase = std::tuple<
+    std::string /* value */,
+    AllocationLifetime,
+    bool /* is_alias */,
+    bool /* is_storage_associated_with_output */,
+    c10::FastSet<std::string> /* source(s) */>;
+
+class AliasAnalyzerTests : public testing::Test {
+  void SetUp() override {}
+
+  void TearDown() override {
+    test_cases.clear();
+    model.clear();
+  }
+
+ public:
+  void setTestCases(std::vector<AliasTestCase> cases) {
+    test_cases = std::move(cases);
+  }
+
+  void setModel(std::string m) {
+    model = std::move(m);
+  }
+
+  void run() {
+    EXPECT_FALSE(test_cases.empty());
+    EXPECT_FALSE(model.empty());
+
+    ExecutorConfig cfg;
+    cfg.enableStaticCPUKernels = true;
+
+    auto graph = stringToGraph(model);
+    auto kernels = KernelFactory().initializeNodeKernels(
+        *graph, nullptr, cfg, {}, nullptr);
+    auto kernelSchemas = Executor::getKernelSchemas(kernels.nodeKernels);
+
+    AliasAnalyzer analyzer(*graph, kernelSchemas);
+
+    for (auto& [value, lifetime, is_alias, is_storage_associated_with_output,
+                srcs] : test_cases) {
+      LOG(INFO) << fmt::format(
+          "running test: value={}, lifetime=({}, {}), is_alias={}, is_storage_associated_with_output={}, src={}",
+          value,
+          lifetime.start,
+          lifetime.end,
+          is_alias,
+          is_storage_associated_with_output,
+          srcs.empty() ? "{}"
+                       : std::accumulate(
+                             srcs.begin(),
+                             srcs.end(),
+                             std::string{},
+                             [](std::string cur, const std::string& src) {
+                               cur.append(",");
+                               cur.append(src);
+                               return cur;
+                             }));
+
+      auto* v = graph->getValue(value);
+
+      EXPECT_EQ(analyzer.lifetime(v), lifetime);
+      EXPECT_EQ(analyzer.is_alias(v), is_alias);
+      EXPECT_EQ(
+          analyzer.is_storage_associated_with_output(v),
+          is_storage_associated_with_output);
+
+      const auto* resolved_srcs = analyzer.get_sources_of_alias(v);
+      if (resolved_srcs /* ensure set equality between *resolved_srcs and srcs */) {
+        EXPECT_FALSE(srcs.empty());
+        EXPECT_EQ(resolved_srcs->size(), srcs.size());
+        for (const auto& resolved_src : *resolved_srcs) {
+          EXPECT_TRUE(srcs.erase(std::string(resolved_src->name())) == 1);
+        }
+        EXPECT_TRUE(srcs.empty());
+      } else {
+        EXPECT_TRUE(srcs.empty());
+      }
+    }
+  }
+
+ private:
+  std::string model;
+  std::vector<AliasTestCase> test_cases;
+};
+
+TEST_F(AliasAnalyzerTests, TestNoAlias) {
+  setModel(R"(
+graph(%y0, %y1):
+  %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
+  %res = torch.ops.aten.clone.default(self=%out_t, memory_format=None)
+  return (%res))");
+
+  setTestCases({
+      {"out_t", AllocationLifetime(1, 2), false, false, {}},
+      {"res", AllocationLifetime(2, 3), false, true, {}},
+  });
+
+  run();
+}
+
+TEST_F(AliasAnalyzerTests, TestSimpleAlias) {
+  setModel(R"(
+graph(%y0, %y1):
+  %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
+  %res = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1)
+  return (%res))");
+
+  setTestCases({
+      {"out_t", AllocationLifetime(1, 3), false, true, {}},
+      {"res", AllocationLifetime(2, 3), true, false, {"out_t"}},
+  });
+
+  run();
+}
+
+TEST_F(AliasAnalyzerTests, TestDeepAlias) {
+  setModel(R"(
+graph(%y0, %y1):
+  %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
+  %a1 = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1)
+  %res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1)
+  return (%res))");
+
+  setTestCases({
+      {"out_t", AllocationLifetime(1, 4), false, true, {}},
+      {"a1", AllocationLifetime(2, 4), true, false, {"out_t"}},
+      {"res", AllocationLifetime(3, 4), true, false, {"out_t"}},
+  });
+
+  run();
+}
+
+TEST_F(AliasAnalyzerTests, TestPackedListUnpack) {
+  setModel(R"(
+graph(%a, %b, %c, %d):
+  %input_list[] = prim.ListPack(l0=%a, l1=%b, l2=%c, l3=%d)
+  %x0, %x1, %x2, %x3 = prim.ListUnpack(input=%input_list)
+  return (%x1, %x3))");
+
+  setTestCases({
+      {"a", AllocationLifetime(0, 2), false, false, {}},
+      {"x0", AllocationLifetime(2, 2), true, false, {"a"}},
+      {"b", AllocationLifetime(0, 3), false, true, {}},
+      {"x1", AllocationLifetime(2, 3), true, false, {"b"}},
+      {"c", AllocationLifetime(0, 2), false, false, {}},
+      {"x2", AllocationLifetime(2, 2), true, false, {"c"}},
+      {"d", AllocationLifetime(0, 3), false, true, {}},
+      {"x3", AllocationLifetime(2, 3), true, false, {"d"}},
+  });
+
+  run();
+}
+
+TEST_F(AliasAnalyzerTests, TestAmbiguousSourceOfAlias) {
+  setModel(R"(
+graph(%y0, %y1):
+  %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
+  %out_t2 = torch.ops.aten.matmul.default(self=%y0, other=%y1)
+  %a1 = prim.VarStack(l0=%out_t, l1=%out_t2)
+  %res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1)
+  return (%res))");
+
+  setTestCases({
+      {"out_t", AllocationLifetime(1, 5), false, true, {}},
+      {"out_t2", AllocationLifetime(2, 5), false, true, {}},
+      {"a1", AllocationLifetime(3, 5), true, false, {"out_t", "out_t2"}},
+      {"res", AllocationLifetime(4, 5), true, false, {"out_t", "out_t2"}},
+  });
+
+  run();
+}
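A pattern worth noting in the expectations above: lifetime endpoints are node indices in execution order, and an alias's lifetime always falls within that of the value whose storage it borrows; the cases also suggest that a source's lifetime is extended to cover its aliases (e.g. out_t ends at the return node whenever one of its views is a graph output). A minimal sketch of that containment invariant, with hypothetical names not taken from this PR:

#include <cstddef>

// Sketch: an alias never outlives its source, since the source's storage
// must stay valid for every view onto it.
struct Interval {
  size_t start;
  size_t end;
};

inline bool contains(const Interval& source, const Interval& alias) {
  return source.start <= alias.start && alias.end <= source.end;
}

// e.g. in TestDeepAlias: out_t spans (1, 4) and contains a1 (2, 4) and res (3, 4).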

torch/nativert/executor/ExecutionFrame.h

Lines changed: 4 additions & 3 deletions
@@ -46,13 +46,14 @@ class ExecutionFrame {
   }
 
   template <typename CB>
-  auto withMemoryPlanner(CB&& cb) {
+  auto withManagedMemory(CB&& cb) {
     if (!layoutManager_) {
-      return std::forward<CB>(cb)();
+      return std::forward<CB>(cb)(nullptr);
     }
 
     LayoutManagerGuard guard(*layoutManager_);
-    return std::forward<CB>(cb)();
+    return std::forward<CB>(cb)(
+        const_cast<const LayoutManager*>(layoutManager_.get()));
   }
 
   std::vector<c10::IValue> tryMoveUserOutputs();
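The rename from withMemoryPlanner to withManagedMemory also changes the callback's shape: it now receives the frame's LayoutManager, or nullptr when memory planning is disabled, so callers can run layout checks against planned memory. A hypothetical caller might look like the following (executeGraph and the commented-out debug hook are stand-ins, not APIs from this diff):

auto outputs = frame.withManagedMemory(
    [&](const LayoutManager* layoutManager) {
      auto result = executeGraph(frame);  // stand-in for the real execution call
      if (layoutManager != nullptr) {
        // in debug builds, the planned layout could be validated here, e.g.
        // checking that no two live planned tensors share bytes (assumed hook)
      }
      return result;
    });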

torch/nativert/executor/Executor.cpp

Lines changed: 19 additions & 16 deletions
@@ -19,30 +19,31 @@ namespace torch::nativert {
 Executor::Executor(
     torch::nativert::ExecutorConfig executorConfig,
     std::shared_ptr<Graph> graph,
-    std::shared_ptr<Weights> weights,
-    const Placement& placement,
-    std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
-    const MakeProxyExecutorFn& makeProxyExecutorFunc)
+    const std::shared_ptr<Weights>& weights,
+    Placement placement,
+    const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
+        pytorchStreamReader,
+    MakeProxyExecutorFn makeProxyExecutorFunc)
     : executorConfig_(std::move(executorConfig)),
       graph_(std::move(graph)),
-      placement_(placement),
+      placement_(std::move(placement)),
       constantFolder_(
           executorConfig_.runConstFolding
              ? std::optional<ConstantFolder>(*graph_)
              : std::nullopt),
-      makeProxyExecutorFunc_(makeProxyExecutorFunc),
+      makeProxyExecutorFunc_(std::move(makeProxyExecutorFunc)),
       executionFrames_(executorConfig_.maxNumConcurrentThreads),
       clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads),
       numExecutionFrames_(0),
       lastClearedTimestamp_(getCurrentTimestampSeconds()) {
   if (weights) {
-    initialize(std::move(weights), std::move(pytorchStreamReader));
+    initialize(weights, pytorchStreamReader);
   }
 }
 
 void Executor::initialize(
-    std::shared_ptr<Weights> weights,
-    std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
+    const std::shared_ptr<Weights>& weights,
+    const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
        pytorchStreamReader) {
   auto start = std::chrono::high_resolution_clock::now();
 
@@ -51,7 +52,7 @@ void Executor::initialize(
       weights,
       executorConfig_,
       placement_,
-      std::move(pytorchStreamReader),
+      pytorchStreamReader,
       makeProxyExecutorFunc_);
 
   if (constantFolder_.has_value()) {
@@ -113,13 +114,14 @@ void Executor::atomicSwapWeights(std::shared_ptr<Weights> weights) {
   }
 }
 
-void Executor::maybeRunConstantFolding(std::shared_ptr<Weights> weights) {
+void Executor::maybeRunConstantFolding(
+    const std::shared_ptr<Weights>& weights) {
   for (auto& execution : constFoldingExecutions_) {
     ExecutionFrame constFoldingFrame(execution.executor->graph());
     std::vector<c10::IValue> inputs;
     inputs.reserve(graph_->signature().inputsToWeights().size());
     for (const auto& [_, name] : graph_->signature().inputsToWeights()) {
-      inputs.push_back(weights->at(name));
+      inputs.emplace_back(weights->at(name));
     }
 
     auto outputs = execution.executor->execute(constFoldingFrame, inputs);
@@ -130,7 +132,7 @@ void Executor::maybeRunConstantFolding(std::shared_ptr<Weights> weights) {
   }
 }
 
-void Executor::processWeights(std::shared_ptr<Weights> weights) {
+void Executor::processWeights(const std::shared_ptr<Weights>& weights) {
   maybeRunConstantFolding(weights);
   if (constantFolder_.has_value()) {
     constantFolder_->evaluate(*weights);
@@ -352,10 +354,10 @@ std::vector<c10::IValue> Executor::execute(
 }
 
 ProfileMetrics Executor::benchmarkIndividualNodes(
-    std::vector<std::vector<c10::IValue>> inputsList,
+    const std::vector<std::vector<c10::IValue>>& inputsList,
     const uint32_t warmupRuns,
     const uint32_t mainRuns) {
-  CHECK(inputsList.size() > 0) << "Need at least one input to benchmark";
+  CHECK(!inputsList.empty()) << "Need at least one input to benchmark";
   CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run";
 
   for (const auto& inputs : inputsList) {
@@ -378,8 +380,9 @@ int64_t Executor::getCurrentTimestampSeconds() const {
 
 std::vector<DelegateExecutor*> Executor::getDelegates() {
   std::vector<DelegateExecutor*> delegates;
+  delegates.reserve(delegateExecutors_.size());
   for (const auto& delegateExecutor : delegateExecutors_) {
-    delegates.push_back(delegateExecutor.get());
+    delegates.emplace_back(delegateExecutor.get());
   }
   return delegates;
 }
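Most of these signature changes follow one pattern: a std::shared_ptr parameter that is only observed is taken by const reference, avoiding an atomic refcount increment on every call, while parameters whose ownership is genuinely transferred (placement, makeProxyExecutorFunc) are taken by value and moved. A minimal self-contained illustration of the difference (Weights here is just a stand-in type):

#include <memory>

struct Weights {};

// Pass-by-value copies the shared_ptr: an atomic refcount bump on every
// call, even though the callee only reads through the pointer.
void byValue(std::shared_ptr<Weights> w) { (void)w; }

// Pass-by-const-ref observes the same object with no refcount traffic;
// this is the pattern applied to processWeights() and initialize() above.
void byConstRef(const std::shared_ptr<Weights>& w) { (void)w; }

int main() {
  auto weights = std::make_shared<Weights>();
  byValue(weights);     // use_count: 1 -> 2, back to 1 on return
  byConstRef(weights);  // use_count stays 1 throughout
}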

torch/nativert/executor/Executor.h

Lines changed: 9 additions & 9 deletions
@@ -79,19 +79,19 @@ class Executor {
   Executor(
       torch::nativert::ExecutorConfig executorConfig,
       std::shared_ptr<Graph> graph,
-      std::shared_ptr<Weights> weights,
-      const Placement& placement = Placement(),
-      std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
+      const std::shared_ptr<Weights>& weights,
+      Placement placement = Placement(),
+      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader = nullptr,
-      const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr);
+      MakeProxyExecutorFn makeProxyExecutorFunc = nullptr);
 
   std::shared_ptr<Weights> getWeights() {
     std::shared_ptr<Weights> ret;
     weights_.withLock([&](auto& w) { ret = w; });
     return ret;
   }
 
-  void processWeights(std::shared_ptr<Weights> weights);
+  void processWeights(const std::shared_ptr<Weights>& weights);
   void atomicSwapWeights(std::shared_ptr<Weights> weights);
 
   // This API only returns the flattened UserOutputs,
@@ -106,7 +106,7 @@ class Executor {
       const ITreeSpec& inputTreeSpec);
 
   ProfileMetrics benchmarkIndividualNodes(
-      std::vector<std::vector<c10::IValue>> inputsList,
+      const std::vector<std::vector<c10::IValue>>& inputsList,
       const uint32_t warmupRuns,
       const uint32_t mainRuns);
 
@@ -141,8 +141,8 @@ class Executor {
   c10::Synchronized<std::shared_ptr<Weights>> weights_;
 
   void initialize(
-      std::shared_ptr<Weights> weights,
-      std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
+      const std::shared_ptr<Weights>& weights,
+      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader);
 
   ExecutorFramePtr getExecutorFrameFromPool();
@@ -171,7 +171,7 @@ class Executor {
     ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete;
   };
 
-  void maybeRunConstantFolding(std::shared_ptr<Weights> weights);
+  void maybeRunConstantFolding(const std::shared_ptr<Weights>& weights);
   void validateInputs(const std::vector<c10::IValue>& inputs) const;
 
   // Helper method to get current timestamp in seconds

torch/nativert/executor/GraphExecutorBase.cpp

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ void GraphExecutorBase::fillUserInputs(
 
 ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
     ExecutionFrame& executionFrame,
-    std::vector<std::vector<c10::IValue>> inputsList,
+    const std::vector<std::vector<c10::IValue>>& inputsList,
     const uint32_t warmupRuns,
     const uint32_t mainRuns) {
   // TODO: add support for memory profiling
@@ -112,7 +112,7 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
   results.totalNodesCount = numNodes;
   for (const auto& r : results.timePerNodeType) {
     const std::string& target = r.first;
-    results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime;
+    results.percentPerNodeType[target] = r.second * 100.0f / results.totalTime;
   }
   return results;
 }
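The 100.0 -> 100.0f change keeps the percentage arithmetic in float: multiplying a float by the double literal 100.0 promotes the whole expression to double and then narrows on the store into percentPerNodeType (whose values are evidently float, judging by this change). In isolation:

int main() {
  float part = 1.0f;
  float total = 3.0f;
  float viaDouble = part * 100.0 / total;   // double math, then a narrowing store (may warn)
  float viaFloat = part * 100.0f / total;   // float math end to end
  (void)viaDouble;
  (void)viaFloat;
}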

torch/nativert/executor/GraphExecutorBase.h

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ class GraphExecutorBase {
 
   ProfileMetrics benchmarkIndividualNodes(
       ExecutionFrame& executionFrame,
-      std::vector<std::vector<c10::IValue>> inputs,
+      const std::vector<std::vector<c10::IValue>>& inputs,
       const uint32_t warmup_runs,
       const uint32_t main_runs);
 