Skip to content

Commit c3082f5

Browse files
committed
refactor: refactor the NodeExecutor logic
Signed-off-by: Bo Wang <[email protected]>
1 parent 4cc3143 commit c3082f5

File tree

5 files changed

+181
-158
lines changed

5 files changed

+181
-158
lines changed

core/compiler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ GraphAndMapping ConstructFallbackGraph_(
228228
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map) {
229229
auto new_g = std::make_shared<torch::jit::Graph>();
230230

231-
auto segmented_blocks = partitioning::partition(partitioning_ctx, block, example_tensor_map);
231+
auto segmented_blocks = partitioning::Partition(partitioning_ctx, block, example_tensor_map);
232232

233233
// the mapping from lowering graph => fallback global graph
234234
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;

core/partitioning/partitioning.cpp

Lines changed: 164 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
2+
3+
4+
15
#include <queue>
26

37
#include "torch/csrc/jit/passes/constant_pooling.h"
@@ -31,6 +35,136 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
3135
return false;
3236
}
3337

38+
39+
40+
// Check if the inputs and outputs of the graph are Tensors. If not, then fall back the connected nodes
41+
void SetInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
42+
// fallback nodes that produce entire graph's nonTensor output
43+
for (auto i : block->outputs()) {
44+
if (!isTensor(i)) {
45+
ctx->setNodeExecutorDecision(i->node(), NodeExecutorDecision::kNON_TENSOR);
46+
}
47+
}
48+
49+
// fallback nodes that consume entire graph's nonTensor input
50+
for (auto i : block->inputs()) {
51+
if (!isTensor(i)) {
52+
for (auto use : i->uses()) {
53+
ctx->setNodeExecutorDecision(use.user, NodeExecutorDecision::kNON_TENSOR);
54+
}
55+
}
56+
}
57+
}
58+
59+
// Find and set all explicit fallback nodes (nodes that are unsupported or forced fallback)
60+
// we use a map to indicate the reason why it falls back to Torch
61+
// For any node that's not explicitly fallback, we set it to run in TensorRT for now
62+
void SetExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
63+
auto nodes = block->nodes();
64+
const auto to_compile_sym = c10::Symbol::attr("to_compile");
65+
66+
for (const auto n : nodes) {
67+
if (n->kind() == torch::jit::prim::Constant) {
68+
continue;
69+
}
70+
71+
if (!conversion::OpSupported(n)) {
72+
// If the op is not supported by the conversion phase it should run in PyTorch
73+
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED);
74+
} else if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
75+
// If the user specifies the op to run in Torch it should run in PyTorch
76+
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kOPERATOR_FALLBACK);
77+
} else if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
78+
// If the user specifies the module containing this op to run in torch it should run in PyTorch
79+
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kMODULE_FALLBACK);
80+
} else {
81+
// Set the rest of the nodes to TensorRT
82+
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
83+
}
84+
}
85+
return;
86+
}
87+
88+
// For a given set of fallback nodes, check their inputs/outputs, if any inputs/outputs of them are NonTensor,
89+
// then the nodes that produce/consume those values should also fall back
90+
void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::Node*>& initial_fallback_nodes) {
91+
// initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
92+
std::queue<torch::jit::Node*> q;
93+
for (auto& node : initial_fallback_nodes) {
94+
q.push(node.first);
95+
}
96+
97+
while (!q.empty()) {
98+
auto cur_node = q.front();
99+
q.pop();
100+
// for every node that produces this fallback node's NonTensor input, they should fallback too
101+
for (auto input : cur_node->inputs()) {
102+
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
103+
ctx->shouldNodeRunInTensorRT(input->node())) {
104+
ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR);
105+
q.push(input->node());
106+
}
107+
}
108+
// for every node that consumes this fallback node's NonTensor output, they should fallback too
109+
for (auto output : cur_node->outputs()) {
110+
if (!isTensor(output)) {
111+
for (auto use : output->uses()) {
112+
auto node = use.user;
113+
if (node->kind() != torch::jit::prim::Constant &&
114+
ctx->shouldNodeRunInTensorRT(node)) {
115+
ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
116+
q.push(node);
117+
}
118+
}
119+
}
120+
}
121+
}
122+
}
123+
124+
// Sub-function that traverses the entire block and checks if the TensorRT node sequences satisfy min_block_size
125+
std::vector<torch::jit::Node*> TraverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
126+
auto nodes = block->nodes();
127+
std::vector<torch::jit::Node*> cur_trt_nodes;
128+
std::vector<torch::jit::Node*> min_block_fallback_nodes;
129+
for (const auto n : nodes) {
130+
if (n->kind() == torch::jit::prim::Constant) {
131+
continue;
132+
}
133+
134+
// check if current node fallback or not
135+
if (!ctx->shouldNodeRunInTorch(n)) {
136+
cur_trt_nodes.push_back(n);
137+
} else {
138+
if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
139+
min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
140+
}
141+
cur_trt_nodes.clear();
142+
}
143+
}
144+
if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
145+
min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
146+
}
147+
return min_block_fallback_nodes;
148+
}
149+
150+
151+
// Set the nodes that fallback because of min_block_size
152+
void SetMinBlockFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
153+
// first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
154+
auto min_block_fallback_nodes = TraverseNodesForMinBlockSize(ctx, block);
155+
156+
// keep fallback until all segments meet the min_block_size requirement
157+
while (!min_block_fallback_nodes.empty()) {
158+
for (const auto i : min_block_fallback_nodes) {
159+
ctx->setNodeExecutorDecision(i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK);
160+
}
161+
// find the fallback nodes because of dependency with min_block_size caused fallback nodes
162+
SetNonTensorConnectedNodes(ctx, min_block_fallback_nodes);
163+
// keep traversing the graph until no node falls back because of min_block_size
164+
min_block_fallback_nodes = TraverseNodesForMinBlockSize(ctx, block);
165+
}
166+
}
167+
34168
bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
35169
const torch::jit::FunctionSchema* schema = node->maybeSchema();
36170
if (!schema) {
@@ -97,62 +231,6 @@ std::vector<torch::jit::Node*> getDependencyNodes(
97231
return stk;
98232
}
99233

100-
// check if the inputs and outputs of the graph are Tensors after collection is enabled. If not, then fall back related
101-
// nodes
102-
void fallback_graph_nontensor_in_out(PartitioningCtx* ctx, torch::jit::Block* block) {
103-
// fallback nodes that produce entire graph's nonTensor output
104-
for (auto i : block->outputs()) {
105-
if (!isTensor(i)) {
106-
ctx->setNodeExecutorDecision(i->node(), NodeExecutorDecision::kNON_TENSOR);
107-
}
108-
}
109-
110-
// fallback nodes that consume entire graph's nonTensor input
111-
for (auto i : block->inputs()) {
112-
if (!isTensor(i)) {
113-
for (auto use : i->uses()) {
114-
ctx->setNodeExecutorDecision(use.user, NodeExecutorDecision::kNON_TENSOR);
115-
}
116-
}
117-
}
118-
}
119-
120-
void find_all_fallback_nodes(PartitioningCtx* ctx, NodeExecutorDecisionMap& initial_fallback_nodes) {
121-
// initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
122-
// global_fallback_nodes are the fallback nodes that we maintain globally
123-
std::queue<torch::jit::Node*> q;
124-
for (auto& node : initial_fallback_nodes) {
125-
q.push(node.first);
126-
}
127-
128-
std::unordered_set<torch::jit::Node*> visited_nodes;
129-
while (!q.empty()) {
130-
auto cur_node = q.front();
131-
q.pop();
132-
// for every node that produces this fallback node's NonTensor input, they should fallback too
133-
for (auto input : cur_node->inputs()) {
134-
// NOTE: This does not make sense, does this rely on short-circuiting to work right?
135-
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
136-
ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR)) {
137-
q.push(input->node());
138-
}
139-
}
140-
// for every node that consumes this fallback node's NonTensor output, they should fallback too
141-
for (auto output : cur_node->outputs()) {
142-
if (!isTensor(output)) {
143-
for (auto use : output->uses()) {
144-
auto node = use.user;
145-
// NOTE: This does not make sense, does this rely on short-circuiting to work right?
146-
if (node->kind() != torch::jit::prim::Constant &&
147-
ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR)) {
148-
q.push(node);
149-
}
150-
}
151-
}
152-
}
153-
}
154-
}
155-
156234
void resolveTRTNonTensorInputs(PartitioningCtx* ctx) {
157235
// if a TRT segment has nonTensor Inputs, the nodes that produce this nonTensor Inputs must in another TensorRT engine
158236
// because we have already found the interface between Torch and TRT in segmentation phase
@@ -250,102 +328,10 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
250328
return compile_to_trt;
251329
}
252330

253-
// use this function to get all initial fallback nodes (nodes that are unsupported or forced fallback)
254-
// we use a map to indicate the reason why it falls back to Torch
255-
void get_fallback_nodes(PartitioningCtx* ctx, torch::jit::Block* block) {
256-
auto nodes = block->nodes();
257-
for (const auto n : nodes) {
258-
if (n->kind() == torch::jit::prim::Constant) {
259-
continue;
260-
}
261-
262-
// If the op is not supported by the conversion phase it should run in PyTorch
263-
if (!conversion::OpSupported(n)) {
264-
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED);
265-
}
266-
267-
// If the user specifies the op to run in Torch it should run in PyTorch
268-
if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
269-
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kOPERATOR_FALLBACK);
270-
}
271-
272-
// If the user specifies the module containing this op to run in torch it should run in PyTorch
273-
const auto to_compile_sym = c10::Symbol::attr("to_compile");
274-
if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
275-
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kMODULE_FALLBACK);
276-
}
277-
}
278-
return;
279-
}
280-
281-
std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(PartitioningCtx* ctx, torch::jit::Block* block) {
282-
auto nodes = block->nodes();
283-
std::vector<torch::jit::Node*> cur_trt_nodes;
284-
std::vector<torch::jit::Node*> min_block_fallback_nodes;
285-
for (const auto n : nodes) {
286-
if (n->kind() == torch::jit::prim::Constant) {
287-
continue;
288-
}
289-
290-
// check if current node fallback or not
291-
if (!ctx->shouldNodeRunInTorch(n)) {
292-
// if this node is not in fallback nodes, then it's in trt segments
293-
cur_trt_nodes.push_back(n);
294-
} else {
295-
if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
296-
min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
297-
}
298-
cur_trt_nodes.clear();
299-
}
300-
}
301-
if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
302-
min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
303-
}
304-
return min_block_fallback_nodes;
305-
}
306-
307-
void find_min_block_size_fallback_nodes(PartitioningCtx* ctx, torch::jit::Block* block) {
308-
// first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
309-
auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(ctx, block);
310-
NodeExecutorDecisionMap initial_fallback_nodes;
311-
312-
// keep fallback until all segments meet the min_block_size requirement
313-
while (!min_block_fallback_nodes.empty()) {
314-
for (const auto i : min_block_fallback_nodes) {
315-
initial_fallback_nodes.insert({i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK});
316-
ctx->setNodeExecutorDecision(i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK);
317-
}
318-
// find the fallback nodes because of dependency with min_block_size caused fallback nodes
319-
find_all_fallback_nodes(ctx, initial_fallback_nodes);
320-
// keep traversing the graph until no node falls back because of min_block_size
321-
min_block_fallback_nodes = traverse_nodes_for_min_block_size(ctx, block);
322-
}
323-
}
324-
325-
void segment_graph(PartitioningCtx* ctx, torch::jit::Block* block) {
326-
// get the initial fallback nodes (nodes that are unsupported or forced fallback)
327-
get_fallback_nodes(ctx, block);
328-
329-
// For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this
330-
// input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node
331-
// that produces this input should also fallback
332-
// TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
333-
find_all_fallback_nodes(ctx, ctx->node_executor_decision_map);
334-
335-
// find all fallback nodes because of the min_block_size requirement
336-
find_min_block_size_fallback_nodes(ctx, block);
337331

332+
void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
338333
auto nodes = block->nodes();
339334

340-
// NOTE: Realize this may be redundant, but will let us have an explicit state for each node. Maybe there is a better
341-
// way for (auto n : nodes) {
342-
// if (!ctx->shouldNodeRunInTorch(n) && !ctx->isNodeExecutorKnown(n)) {
343-
// if (conversion::OpSupported(n)) {
344-
// ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
345-
// }
346-
// }
347-
// }
348-
349335
// segment the nodes
350336
std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
351337
for (const auto n : nodes) {
@@ -420,18 +406,41 @@ void segment_graph(PartitioningCtx* ctx, torch::jit::Block* block) {
420406
return;
421407
}
422408

423-
PartitionedGraph partition(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
409+
void SetNodeExecutorDecision(PartitioningCtx* ctx, torch::jit::Block* block) {
410+
// First, find all the explicit fallback nodes that should run in Torch:
411+
// 1. nodes that are unsupported
412+
// 2. nodes that the user specifies to run in torch
413+
// 3. nodes that the user specifies the module containing this op to run in torch
414+
// At the same time, set all the rest nodes to NodeExecutorDecision::kCONVERT
415+
SetExplicitFallbackNodes(ctx, block);
416+
417+
// Second, check if there is nonTensor input/output for the block, if there is, then fallback the nodes that
418+
// consume/produce this nonTensor value
419+
SetInputsOutputsConnectedNodes(ctx, block);
420+
421+
// Third, for fallback nodes, if it consumes any NonTensor inputs, then the nodes that produce this
422+
// input should also fallback. Similarly, if it produces any NonTensor outputs, then the nodes
423+
// that consume this output should also fallback
424+
auto cur_fallback_nodes = ctx->getNodesRunInTorch();
425+
SetNonTensorConnectedNodes(ctx, cur_fallback_nodes);
426+
427+
// Finally, check if all current tensorrt blocks satisfy the min_block_size requirement.
428+
// We need to traverse the whole graph many times here
429+
SetMinBlockFallbackNodes(ctx, block);
430+
}
431+
432+
433+
434+
PartitionedGraph Partition(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
424435
LOG_DEBUG(ctx->settings);
425-
// if there is nonTensor input/output for the entire graph, fallback the node that consumes/produces this nonTensor
426-
// output
427-
fallback_graph_nontensor_in_out(ctx, block);
436+
437+
SetNodeExecutorDecision(ctx, block);
428438

429439
// segment lowering global graph into blocks
430440
LOG_DEBUG("Partitioning source module into PyTorch and TensorRT sub blocks");
431-
segment_graph(ctx, block);
441+
SegmentGraph(ctx, block);
432442

433443
// It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
434-
435444
// resolve nonTensor inputs/outputs
436445
resolveTRTNonTensorInputs(ctx);
437446

core/partitioning/partitioning.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ void runShapeAnalysis(PartitioningCtx* ctx, ExampleIValues& ivalues_maps);
2323

2424
void segment_graph(PartitioningCtx* ctx, torch::jit::Block* block);
2525

26-
PartitionedGraph partition(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map);
26+
PartitionedGraph Partition(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map);
2727

2828
} // namespace partitioning
2929
} // namespace core

0 commit comments

Comments
 (0)