@@ -73,7 +73,6 @@ void SetExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
       // Set the rest of the nodes to TensorRT
       ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
     }
-
   }
   return;
 }
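SetExplicitFallbackNodes fills a per-node decision table that later passes query through shouldNodeRunInTensorRT/shouldNodeRunInTorch. A minimal sketch of what such a lookup table might look like, assuming a simple map inside PartitioningCtx (the real member layout and most enum members are not shown in this diff, so treat names here as assumptions):

    #include <unordered_map>

    namespace torch { namespace jit { struct Node; } }  // forward declaration, for the sketch only

    // kNON_TENSOR and kCONVERT appear in this diff; the other members are assumptions.
    enum class NodeExecutorDecision { kUNSUPPORTED, kOPERATOR_FALLBACK, kMODULE_FALLBACK, kNON_TENSOR, kCONVERT };

    struct DecisionLUT {
      std::unordered_map<torch::jit::Node*, NodeExecutorDecision> decisions;

      void setNodeExecutorDecision(torch::jit::Node* n, NodeExecutorDecision d) { decisions[n] = d; }

      bool shouldNodeRunInTensorRT(torch::jit::Node* n) const {
        auto it = decisions.find(n);
        return it != decisions.end() && it->second == NodeExecutorDecision::kCONVERT;
      }

      // Simplified: the real context may treat unset entries differently.
      bool shouldNodeRunInTorch(torch::jit::Node* n) const { return !shouldNodeRunInTensorRT(n); }
    };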
@@ -103,7 +102,8 @@ void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
       if (!isTensor(output)) {
         for (auto use : output->uses()) {
           auto node = use.user;
-          if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) {
+          if (node->kind() != torch::jit::prim::Constant && node->kind() != torch::jit::prim::Return &&
+              ctx->shouldNodeRunInTensorRT(node)) {
             ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
             q.push(node);
           }
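The added prim::Return check matters because a block's return node is a graph sink with no meaningful executor decision, so it must be excluded before the lookup table is queried. The surrounding loop is a standard worklist propagation; a self-contained toy version over integer node ids:

    #include <queue>
    #include <set>
    #include <vector>

    // Toy version of the nonTensor propagation: an edge u -> v means
    // "v consumes a nonTensor output of u". Starting from the initial fallback
    // set, every reachable consumer is demoted as well.
    std::set<int> PropagateFallback(const std::vector<std::vector<int>>& non_tensor_users,
                                    std::set<int> fallback) {
      std::queue<int> q;
      for (int n : fallback) q.push(n);
      while (!q.empty()) {
        int cur = q.front();
        q.pop();
        for (int user : non_tensor_users[cur]) {
          if (fallback.insert(user).second) {  // analogous to flipping kCONVERT -> kNON_TENSOR
            q.push(user);
          }
        }
      }
      return fallback;
    }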
@@ -175,7 +175,7 @@ bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
   return false;
 }
 
-std::vector<torch::jit::Node*> findModifyingNodes(
+std::vector<torch::jit::Node*> FindModifyingNodes(
     torch::jit::Value* val,
     const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
   std::vector<torch::jit::Node*> modifying_nodes;
@@ -192,7 +192,7 @@ std::vector<torch::jit::Node*> findModifyingNodes(
 }
 
 // this function is only used when a TRT segment produces nonTensor values which are used by a later TRT segment
-std::vector<torch::jit::Node*> getDependencyNodes(
+std::vector<torch::jit::Node*> GetDependencyNodes(
     const std::vector<torch::jit::Value*>& vals,
     const SegmentedBlock& seg_block) {
   // get all nodes in the segmented block
@@ -208,7 +208,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
     auto node = cur_val->node();
     if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
       visited.insert(node);
-      auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes);
+      auto modifying_nodes = FindModifyingNodes(cur_val, seg_block_nodes);
       stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
       stk.push_back(node);
       for (auto input : node->inputs()) {
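GetDependencyNodes walks backwards from a segment's unresolved nonTensor inputs, collecting every transitive producer (splicing in any in-place-modifying nodes found by FindModifyingNodes) and finally returning the list so producers precede consumers. The same producers-first ordering can be obtained with a recursive postorder walk; a toy version over integer node ids (the real code traverses torch::jit::Value objects with an explicit worklist):

    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    // Toy dependency walk: emit every transitive producer of `node` before the node itself.
    void CollectDeps(int node,
                     const std::unordered_map<int, std::vector<int>>& inputs_of,
                     std::unordered_set<int>& visited,
                     std::vector<int>& out) {
      if (!visited.insert(node).second) return;  // already emitted
      auto it = inputs_of.find(node);
      if (it != inputs_of.end()) {
        for (int producer : it->second) CollectDeps(producer, inputs_of, visited, out);
      }
      out.push_back(node);  // postorder: all producers are already in `out`
    }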
@@ -222,7 +222,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
-void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
+void ResolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   // if a TRT segment has nonTensor inputs, the nodes that produce these nonTensor inputs must be in another TensorRT engine
   // because we have already found the interface between Torch and TRT in the segmentation phase
   // what we do here is just find the dependency nodes of the TRT segments that have nonTensor inputs
@@ -236,16 +236,19 @@ void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
         }
       }
       if (!inputs_to_resolve.empty()) {
-        std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]);
+        std::vector<torch::jit::Node*> dependency_nodes =
+            GetDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]);
         dependency_nodes.insert(
-            dependency_nodes.end(), cur_partitioned_block[i].raw_nodes().begin(), cur_partitioned_block[i].raw_nodes().end());
+            dependency_nodes.end(),
+            cur_partitioned_block[i].raw_nodes().begin(),
+            cur_partitioned_block[i].raw_nodes().end());
         cur_partitioned_block[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes);
       }
     }
   }
 }
 
-void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
+void RegisterSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   // find the corresponding raw values in the original global graph for this segmented block's inputs/outputs
   PartitionedGraph& cur_partitioned_block = ctx->partitioned_blocks[block];
   auto cmp = [](torch::jit::Value* a, torch::jit::Value* b) { return a->unique() < b->unique(); };
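After the dependency nodes are gathered, the rebuild in ResolveTRTNonTensorInputs is a plain concatenation, producers first, then the segment's own raw nodes, followed by reconstructing the SegmentedBlock. Continuing the toy CollectDeps above (all variable names here are illustrative, not the real SegmentedBlock API):

    // Rebuild a segment so it also carries the producers of its nonTensor inputs.
    std::unordered_set<int> visited;
    std::vector<int> rebuilt;
    for (int input : inputs_to_resolve) CollectDeps(input, inputs_of, visited, rebuilt);
    rebuilt.insert(rebuilt.end(), segment_nodes.begin(), segment_nodes.end());
    // segment = SegmentedBlock(SegmentedBlock::kTensorRT, rebuilt) in the real code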
@@ -331,21 +334,46 @@ void finalizeNewBlock(
   LOG_DEBUG(g.back());
 }
 
+void SetNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
+  // First, find all the explicit fallback nodes that should run in Torch:
+  // 1. nodes that are unsupported
+  // 2. nodes that the user specifies to run in Torch
+  // 3. nodes whose containing module the user specifies to run in Torch
+  // At the same time, set all the remaining nodes to NodeExecutorDecision::kCONVERT
+  SetExplicitFallbackNodes(ctx, block);
+
+  // Second, check if there are nonTensor inputs/outputs for the block; if there are, then fall back the nodes that
+  // consume/produce these nonTensor values
+  SetInputsOutputsConnectedNodes(ctx, block);
+
+  // Third, for each fallback node, if it consumes any nonTensor inputs, then the nodes that produce those
+  // inputs should also fall back. Similarly, if it produces any nonTensor outputs, then the nodes
+  // that consume those outputs should also fall back
+  auto cur_fallback_nodes = ctx->getNodesRunInTorch();
+  SetNonTensorConnectedNodes(ctx, cur_fallback_nodes);
+
+  // Finally, check that all current TensorRT blocks satisfy the min_block_size requirement.
+  // We need to traverse the whole graph many times here
+  SetMinBlockFallbackNodes(ctx, block);
+}
+
 void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
+  // Find all the fallback nodes and build the execution decision LUT for all nodes
+  SetNodeExecutorLUT(ctx, block);
+
   auto nodes = block->nodes();
 
   // segment the nodes
   PartitionedGraph segmented_blocks;
 
   std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
   for (const auto n : nodes) {
-
     // Skip constant nodes as they are resources for both kinds of modules
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
     }
     // the outputs of a TRT subgraph shouldn't be collections
-    if (!ctx->shouldNodeRunInTorch(n)) {
+    if (ctx->shouldNodeRunInTensorRT(n)) {
       in_prog_trt_blk_nodes.push_back(n);
 
       // If there is an active PyTorch block and we have passed the threshold for a valid TRT
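The loop that follows in SegmentGraph greedily grows runs of same-target nodes and flushes a block whenever the target engine flips (the real loop additionally applies the min_block_size threshold before committing a TRT run). A self-contained toy of this greedy segmentation, with node ids paired with a should-run-in-TRT flag:

    #include <utility>
    #include <vector>

    enum class Target { kTensorRT, kTorch };

    // Greedy segmentation: contiguous nodes with the same target form one block.
    std::vector<std::pair<Target, std::vector<int>>> Segment(
        const std::vector<std::pair<int, bool>>& nodes /* (id, runs_in_trt) */) {
      std::vector<std::pair<Target, std::vector<int>>> blocks;
      std::vector<int> run;
      Target cur = Target::kTensorRT;
      for (const auto& [id, in_trt] : nodes) {
        Target want = in_trt ? Target::kTensorRT : Target::kTorch;
        if (want != cur && !run.empty()) {
          blocks.emplace_back(cur, std::move(run));  // flush the finished run
          run.clear();
        }
        cur = want;
        run.push_back(id);
      }
      if (!run.empty()) blocks.emplace_back(cur, std::move(run));
      return blocks;
    }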
@@ -410,65 +438,26 @@ void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   return;
 }
 
-void SetNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
-  // First, find all the explicit fallback nodes that should run in Torch:
-  // 1. nodes that are unsupported
-  // 2. nodes that the user specifies to run in Torch
-  // 3. nodes whose containing module the user specifies to run in Torch
-  // At the same time, set all the remaining nodes to NodeExecutorDecision::kCONVERT
-  SetExplicitFallbackNodes(ctx, block);
-
-  // Second, check if there are nonTensor inputs/outputs for the block; if there are, then fall back the nodes that
-  // consume/produce these nonTensor values
-  SetInputsOutputsConnectedNodes(ctx, block);
-
-  // Third, for each fallback node, if it consumes any nonTensor inputs, then the nodes that produce those
-  // inputs should also fall back. Similarly, if it produces any nonTensor outputs, then the nodes
-  // that consume those outputs should also fall back
-  auto cur_fallback_nodes = ctx->getNodesRunInTorch();
-  SetNonTensorConnectedNodes(ctx, cur_fallback_nodes);
-
-  // Finally, check that all current TensorRT blocks satisfy the min_block_size requirement.
-  // We need to traverse the whole graph many times here
-  SetMinBlockFallbackNodes(ctx, block);
-}
-
 void Partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
   LOG_DEBUG(ctx->settings);
 
   // Go through all the blocks to do the partitioning
   for (torch::jit::Block* block : ctx->original_blocks) {
-
-    // Find all the fallback nodes and build the execution decision LUT for all nodes
-    SetNodeExecutorLUT(ctx, block);
-
     // segment the lowered global graph into blocks
     SegmentGraph(ctx, block);
 
     // It's possible that some TensorRT blocks have nonTensor inputs/outputs because they are interleaved by Torch blocks
     // resolve nonTensor inputs/outputs
-    resolveTRTNonTensorInputs(ctx, block);
+    ResolveTRTNonTensorInputs(ctx, block);
 
     // register input/output torch::jit::Value for segmented graphs
     LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
-    registerSegmentsOutputs(ctx, block);
+    RegisterSegmentsOutputs(ctx, block);
 
-    for (auto& i : ctx->partitioned_blocks[block]) {
-      LOG_DEBUG(i);
-    }
 
     // run shape analysis on each segmented block
-    runShapeAnalysis(ctx, block, example_tensor_map);
-
+    RunShapeAnalysis(ctx, block, example_tensor_map);
   }
-
-
-
-  // for (uint64_t i = 0; i < ctx->blocks.size(); i++) {
-  //   ctx->blocks[i].update_id(i);
-  // }
-
-
 }
 
 } // namespace partitioning
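One detail the SetNodeExecutorLUT comments mention without showing is the min_block_size pass. Its core rule can be illustrated on a boolean sequence: any run of consecutive TRT-tagged nodes shorter than the threshold is demoted back to Torch. A toy of that idea (the real SetMinBlockFallbackNodes operates on the decision LUT and may traverse the graph several times):

    #include <cstddef>
    #include <vector>

    // Demote any run of consecutive TRT-convertible nodes shorter than min_block_size.
    void DemoteShortTRTRuns(std::vector<bool>& runs_in_trt, std::size_t min_block_size) {
      std::size_t i = 0;
      while (i < runs_in_trt.size()) {
        if (!runs_in_trt[i]) { ++i; continue; }
        std::size_t j = i;
        while (j < runs_in_trt.size() && runs_in_trt[j]) ++j;  // [i, j) is a TRT run
        if (j - i < min_block_size) {
          for (std::size_t k = i; k < j; ++k) runs_in_trt[k] = false;  // fall back the whole run
        }
        i = j;
      }
    }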