Skip to content

Commit 77f4c09

Browse files
committed
fix: fix bugs found when running tests
Signed-off-by: Bo Wang <[email protected]>
1 parent ff22707 commit 77f4c09

File tree

12 files changed

+121
-156
lines changed

core/compiler.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,11 @@ partitioning::GraphAndMapping BuildHybridGraph(
138138

139139
auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
140140
auto collection_input_ivalues_map =
141-
partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
141+
partitioning::GenerateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
142142

143143
partitioning::Partition(&partitioning_ctx, collection_input_ivalues_map);
144144

145-
for (auto &partitioned_block : partitioning_ctx.partitioned_blocks) {
145+
for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
146146
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
147147

148148
for (auto& seg_block : segmented_blocks) {

core/partitioning/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ cc_library(
1515
srcs = [
1616
"partitioning.cpp",
1717
"shape_analysis.cpp",
18-
"stitching.cpp"
18+
"stitching.cpp",
1919
],
2020
hdrs = [
2121
"partitioning.h",

core/partitioning/partitioning.cpp

Lines changed: 42 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ void SetExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
7373
// Set the rest nodes to TensorRt
7474
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
7575
}
76-
7776
}
7877
return;
7978
}
@@ -103,7 +102,8 @@ void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
103102
if (!isTensor(output)) {
104103
for (auto use : output->uses()) {
105104
auto node = use.user;
106-
if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) {
105+
if (node->kind() != torch::jit::prim::Constant && node->kind() != torch::jit::prim::Return &&
106+
ctx->shouldNodeRunInTensorRT(node)) {
107107
ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
108108
q.push(node);
109109
}
@@ -175,7 +175,7 @@ bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
175175
return false;
176176
}
177177

178-
std::vector<torch::jit::Node*> findModifyingNodes(
178+
std::vector<torch::jit::Node*> FindModifyingNodes(
179179
torch::jit::Value* val,
180180
const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
181181
std::vector<torch::jit::Node*> modifying_nodes;
@@ -192,7 +192,7 @@ std::vector<torch::jit::Node*> findModifyingNodes(
192192
}
193193

194194
// this function is only used when a TRT segment produces nonTensor values which are used by later TRT segment
195-
std::vector<torch::jit::Node*> getDependencyNodes(
195+
std::vector<torch::jit::Node*> GetDependencyNodes(
196196
const std::vector<torch::jit::Value*>& vals,
197197
const SegmentedBlock& seg_block) {
198198
// get all nodes in the segmentedblock
@@ -208,7 +208,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
208208
auto node = cur_val->node();
209209
if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
210210
visited.insert(node);
211-
auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes);
211+
auto modifying_nodes = FindModifyingNodes(cur_val, seg_block_nodes);
212212
stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
213213
stk.push_back(node);
214214
for (auto input : node->inputs()) {
@@ -222,7 +222,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
222222
return stk;
223223
}
224224

225-
void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
225+
void ResolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
226226
// if a TRT segment has nonTensor Inputs, the nodes that produce this nonTensor Inputs must in another TensorRT engine
227227
// because we have already found the interface between Torch and TRT in segmentation phase
228228
// what we do here is just find the dependency nodes of the TRT segments that have nonTensor inputs
@@ -236,16 +236,19 @@ void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
236236
}
237237
}
238238
if (!inputs_to_resolve.empty()) {
239-
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]);
239+
std::vector<torch::jit::Node*> dependency_nodes =
240+
GetDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]);
240241
dependency_nodes.insert(
241-
dependency_nodes.end(), cur_partitioned_block[i].raw_nodes().begin(), cur_partitioned_block[i].raw_nodes().end());
242+
dependency_nodes.end(),
243+
cur_partitioned_block[i].raw_nodes().begin(),
244+
cur_partitioned_block[i].raw_nodes().end());
242245
cur_partitioned_block[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes);
243246
}
244247
}
245248
}
246249
}
247250

248-
void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
251+
void RegisterSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
249252
// find the corresponding raw values in original global graph for this segmented block's inputs/outputs
250253
PartitionedGraph& cur_partitioned_block = ctx->partitioned_blocks[block];
251254
auto cmp = [](torch::jit::Value* a, torch::jit::Value* b) { return a->unique() < b->unique(); };
@@ -331,21 +334,46 @@ void finalizeNewBlock(
331334
LOG_DEBUG(g.back());
332335
}
333336

337+
void SetNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
338+
// First, find all the explicit fallback nodes that should run in Torch:
339+
// 1. nodes that are unsupported
340+
// 2. nodes that the user specifies to run in torch
341+
// 3. nodes that the user specifies the module containing this op to run in torch
342+
// At the same time, set all the rest nodes to NodeExecutorDecision::kCONVERT
343+
SetExplicitFallbackNodes(ctx, block);
344+
345+
// Second, check if there is nonTensor input/output for the block, if there is, then fallback the nodes that
346+
// consume/produce this nonTensor value
347+
SetInputsOutputsConnectedNodes(ctx, block);
348+
349+
// Third, for fallback nodes, if it consumes any NonTensor inputs, then the nodes that produce this
350+
// input should also fallback. Similarly, if it produces any NonTensor outputs, then the nodes
351+
// that consume this output should also fallback
352+
auto cur_fallback_nodes = ctx->getNodesRunInTorch();
353+
SetNonTensorConnectedNodes(ctx, cur_fallback_nodes);
354+
355+
// Finally, check if all current tensorrt blocks satisfy the min_block_size requirement.
356+
// We need to traverse the whole graph many times here
357+
SetMinBlockFallbackNodes(ctx, block);
358+
}
359+
334360
void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
361+
// Find all the fallback nodes and build execution decision LUT for all nodes
362+
SetNodeExecutorLUT(ctx, block);
363+
335364
auto nodes = block->nodes();
336365

337366
// segment the nodes
338367
PartitionedGraph segmented_blocks;
339368

340369
std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
341370
for (const auto n : nodes) {
342-
343371
// Skip constant nodes as they are resources for both kinds of modules
344372
if (n->kind() == torch::jit::prim::Constant) {
345373
continue;
346374
}
347375
// the outputs of trt subgraph shouldn't be collections
348-
if (!ctx->shouldNodeRunInTorch(n)) {
376+
if (ctx->shouldNodeRunInTensorRT(n)) {
349377
in_prog_trt_blk_nodes.push_back(n);
350378

351379
// If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -410,65 +438,26 @@ void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
410438
return;
411439
}
412440

413-
void SetNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
414-
// First, find all the explicit fallback nodes that should run in Torch:
415-
// 1. nodes that are unsupported
416-
// 2. nodes that the user specifies to run in torch
417-
// 3. nodes that the user specifies the module containing this op to run in torch
418-
// At the same time, set all the rest nodes to NodeExecutorDecision::kCONVERT
419-
SetExplicitFallbackNodes(ctx, block);
420-
421-
// Second, check if there is nonTensor input/output for the block, if there is, then fallback the nodes that
422-
// consume/produce this nonTensor value
423-
SetInputsOutputsConnectedNodes(ctx, block);
424-
425-
// Third, for fallback nodes, if it consumes any NonTensor inputs, then the nodes that produce this
426-
// input should also fallback. Similarly, if it produces any NonTensor outputs, then the nodes
427-
// that consume this output should also fallback
428-
auto cur_fallback_nodes = ctx->getNodesRunInTorch();
429-
SetNonTensorConnectedNodes(ctx, cur_fallback_nodes);
430-
431-
// Finally, check if all current tensorrt blocks satisfy the min_block_size requirement.
432-
// We need to traverse the whole graph many times here
433-
SetMinBlockFallbackNodes(ctx, block);
434-
}
435-
436441
void Partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
437442
LOG_DEBUG(ctx->settings);
438443

439444
// Go through all the blocks to do the partitioning
440445
for (torch::jit::Block* block : ctx->original_blocks) {
441-
442-
// Find all the fallback nodes and build execution decision LUT for all nodes
443-
SetNodeExecutorLUT(ctx, block);
444-
445446
// segment lowering global graph into blocks
446447
SegmentGraph(ctx, block);
447448

448449
// It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
449450
// resolve nonTensor inputs/outputs
450-
resolveTRTNonTensorInputs(ctx, block);
451+
ResolveTRTNonTensorInputs(ctx, block);
451452

452453
// register input/output torch::jit::Value for segmented graphs
453454
LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
454-
registerSegmentsOutputs(ctx, block);
455+
RegisterSegmentsOutputs(ctx, block);
455456

456-
for (auto &i : ctx->partitioned_blocks[block]) {
457-
LOG_DEBUG(i);
458-
}
459457

460458
// run shape analysis on each segmented block
461-
runShapeAnalysis(ctx, block, example_tensor_map);
462-
459+
RunShapeAnalysis(ctx, block, example_tensor_map);
463460
}
464-
465-
466-
467-
// for (uint64_t i = 0; i < ctx->blocks.size(); i++) {
468-
// ctx->blocks[i].update_id(i);
469-
// }
470-
471-
472461
}
473462

474463
} // namespace partitioning

core/partitioning/partitioning.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ typedef std::unordered_map<const torch::jit::Value*, torch::jit::IValue> Example
2020
typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
2121
GraphAndMapping;
2222

23-
ExampleIValues generateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types);
23+
ExampleIValues GenerateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types);
2424

25-
void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps);
25+
void RunShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps);
2626

27-
void segment_graph(PartitioningCtx* ctx, torch::jit::Block* block);
27+
void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
2828

2929
GraphAndMapping Stitch(PartitioningCtx* ctx, torch::jit::Block* block);
3030

core/partitioning/partitioningctx/PartitioningCtx.cpp

Lines changed: 11 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ PartitioningCtx::PartitioningCtx(torch::jit::Block* b, PartitioningInfo info)
1515
}
1616

1717
void PartitioningCtx::_load_nodes_into_decision_map(torch::jit::Block* b) {
18-
original_blocks.push_back(b);
18+
if (!b->owningNode() || b->owningNode()->kind() != torch::jit::prim::Loop) {
19+
original_blocks.push_back(b);
20+
}
1921
for (const auto n : b->nodes()) {
2022
if (n->kind() == torch::jit::prim::Constant) {
2123
continue;
@@ -33,60 +35,28 @@ void PartitioningCtx::setNodeExecutorDecision(torch::jit::Node* n, NodeExecutorD
3335
if (iter != node_executor_decision_map.end()) {
3436
prev_decision = iter->second;
3537
}
36-
LOG_GRAPH("Setting node " << util::node_info(n) << " " << decision << " (previously was " << prev_decision << ")");
37-
38-
// NOTE: This is this way due to partitioning.cpp L#134 I dont know if this is what we should do.
38+
LOG_DEBUG("Setting node " << util::node_info(n) << " " << decision << " (previously was " << prev_decision << ")");
3939

40-
auto result = node_executor_decision_map[n] = decision;
41-
return ;
40+
node_executor_decision_map[n] = decision;
41+
return;
4242
}
4343

4444
bool PartitioningCtx::shouldNodeRunInTorch(torch::jit::Node* n) {
4545
auto iter = node_executor_decision_map.find(n);
46-
auto decision = NodeExecutorDecision::kUNKNOWN;
47-
if (iter != node_executor_decision_map.end()) {
48-
decision = iter->second;
49-
}
50-
51-
if (decision == NodeExecutorDecision::kCONVERT || decision == NodeExecutorDecision::kUNKNOWN) {
52-
return false;
53-
} else {
54-
return true;
46+
if (iter == node_executor_decision_map.end()) {
47+
LOG_ERROR("No info about node " << *n << " execution decision status.");
5548
}
49+
return iter->second != NodeExecutorDecision::kCONVERT;
5650
}
5751

5852
bool PartitioningCtx::shouldNodeRunInTensorRT(torch::jit::Node* n) {
59-
auto iter = node_executor_decision_map.find(n);
60-
auto decision = NodeExecutorDecision::kUNKNOWN;
61-
if (iter != node_executor_decision_map.end()) {
62-
decision = iter->second;
63-
}
64-
65-
if (decision == NodeExecutorDecision::kCONVERT) {
66-
return true;
67-
} else {
68-
return false;
69-
}
70-
}
71-
72-
bool PartitioningCtx::isNodeExecutorKnown(torch::jit::Node* n) {
73-
auto iter = node_executor_decision_map.find(n);
74-
auto decision = NodeExecutorDecision::kUNKNOWN;
75-
if (iter != node_executor_decision_map.end()) {
76-
decision = iter->second;
77-
}
78-
79-
if (decision == NodeExecutorDecision::kUNKNOWN) {
80-
return false;
81-
} else {
82-
return true;
83-
}
53+
return !shouldNodeRunInTorch(n);
8454
}
8555

8656
std::vector<torch::jit::Node*> PartitioningCtx::getNodesRunInTorch() {
8757
std::vector<torch::jit::Node*> nodes_run_in_torch;
8858
for (auto i : node_executor_decision_map) {
89-
if (i.second == NodeExecutorDecision::kCONVERT) {
59+
if (i.second != NodeExecutorDecision::kCONVERT) {
9060
nodes_run_in_torch.push_back(i.first);
9161
}
9262
}

core/partitioning/partitioningctx/PartitioningCtx.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,18 @@ struct UsageInfo {
4747
struct PartitioningCtx {
4848
// TODO: Make the set a part of settings not stand alone
4949
PartitioningInfo settings;
50+
// records all the original blocks topologically in the module
5051
std::vector<torch::jit::Block*> original_blocks;
52+
// mapping: node=> execution status
5153
NodeExecutorDecisionMap node_executor_decision_map;
54+
// LUT of the segmented blocks for each blocks in the module
5255
std::unordered_map<torch::jit::Block*, PartitionedGraph> partitioned_blocks;
5356
std::unordered_set<std::string> forced_fallback_ops;
5457

5558
PartitioningCtx(torch::jit::Block* b, PartitioningInfo info);
5659
void setNodeExecutorDecision(torch::jit::Node* n, NodeExecutorDecision decision);
5760
bool shouldNodeRunInTorch(torch::jit::Node* n);
5861
bool shouldNodeRunInTensorRT(torch::jit::Node* n);
59-
bool isNodeExecutorKnown(torch::jit::Node* n);
6062
std::vector<torch::jit::Node*> getNodesRunInTorch();
6163

6264
private:

core/partitioning/shape_analysis.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace torch_tensorrt {
99
namespace core {
1010
namespace partitioning {
1111

12-
at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>& type_opt) {
12+
at::Tensor GenerateSingleInput(ir::Input& input, c10::optional<at::ScalarType>& type_opt) {
1313
auto cur_shape = input.input_shape;
1414
std::vector<int64_t> shape;
1515
shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
@@ -25,7 +25,7 @@ at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>&
2525
return in;
2626
}
2727

28-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
28+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> GenerateRandomInputs(
2929
std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>>& inputs,
3030
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types) {
3131
// generate random inputs for running pytorch segments
@@ -38,28 +38,28 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
3838
c10::TypePtr elementType = c10::TensorType::get();
3939
auto generic_list = c10::impl::GenericList(elementType);
4040
for (size_t i = 0; i < input.second.size(); i++) {
41-
auto in = generateSingleInput(input.second[i], types[input.first][i]);
41+
auto in = GenerateSingleInput(input.second[i], types[input.first][i]);
4242
generic_list.push_back(in.clone());
4343
}
4444
ivalue_map[input.first] = c10::IValue(generic_list);
4545
} else if (input.first->type()->kind() == torch::jit::TypeKind::TupleType) {
4646
// create tuple
4747
std::vector<torch::jit::IValue> list;
4848
for (size_t i = 0; i < input.second.size(); i++) {
49-
auto in = generateSingleInput(input.second[i], types[input.first][i]);
49+
auto in = GenerateSingleInput(input.second[i], types[input.first][i]);
5050
list.push_back(in.clone());
5151
}
5252
auto tuple = c10::ivalue::Tuple::create(list); // create tuple ptr
5353
ivalue_map[input.first] = c10::IValue(tuple);
5454
} else {
55-
auto in = generateSingleInput(input.second[0], types[input.first][0]);
55+
auto in = GenerateSingleInput(input.second[0], types[input.first][0]);
5656
ivalue_map[input.first] = in.clone();
5757
}
5858
}
5959
return ivalue_map;
6060
}
6161

62-
void getSegmentsOutputByRunning(
62+
void GetSegmentsOutputByRunning(
6363
SegmentedBlock& seg_block,
6464
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
6565
const PartitioningInfo& partitioning_info) {
@@ -181,11 +181,11 @@ void getSegmentsOutputByRunning(
181181
seg_block.register_intypes(input_types);
182182
}
183183

184-
void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
184+
void RunShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
185185
// register every segment's input shape, and it's running output IValues
186186
for (auto& seg_block : ctx->partitioned_blocks[block]) {
187187
torch::jit::ConstantPooling(seg_block.g());
188-
getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings);
188+
GetSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings);
189189
}
190190
return;
191191
}

0 commit comments

Comments
 (0)