Commit eb9e1f6

Merge branch 'master' into build_workspace

2 parents: bafa675 + edf9ee4

21 files changed: +266 −20 lines

.github/pr-labels.yml

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@
 "component: evaluators":
   - core/conversion/evaluators/**/*

+"component: partitioning":
+  - core/partitioning/**/*
+
 "component: runtime":
   - core/runtime/**/*

core/conversion/conversion.cpp

Lines changed: 1 addition & 1 deletion
@@ -326,7 +326,7 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
   MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
   for (auto bn : n->blocks()[0]->nodes()) {
     if (bn->kind() == torch::jit::prim::Loop) {
-      EvaluateLoopBlock(ctx, n);
+      EvaluateLoopBlock(ctx, bn);
     } else if (bn->kind() == torch::jit::prim::If) {
       EvaluateConditionalBlock(ctx, bn, true);
     } else {
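The one-line fix above makes the recursive call descend into the nested loop node bn instead of re-evaluating the outer node n, which would recurse on the same loop indefinitely. A minimal sketch of the same descend-into-the-child pattern, using a hypothetical Node type rather than TRTorch's or TorchScript's actual API:

#include <vector>

// Hypothetical stand-in for a graph node that owns a body of nested nodes.
struct Node {
  bool is_loop;
  std::vector<Node*> body;
};

// Recurse on the nested loop node, never on the node currently being visited;
// calling EvaluateLoop(n) from inside this loop would never terminate.
void EvaluateLoop(Node* n) {
  for (Node* child : n->body) {
    if (child->is_loop) {
      EvaluateLoop(child);
    } else {
      // evaluate an ordinary node here
    }
  }
}

int main() {
  Node inner{true, {}};
  Node outer{true, {&inner}};
  EvaluateLoop(&outer);
}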

core/lowering/passes/module_fallback.cpp

Lines changed: 1 addition & 1 deletion
@@ -133,4 +133,4 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
 } // namespace passes
 } // namespace lowering
 } // namespace core
-} // namespace trtorch
+} // namespace trtorch

core/lowering/passes/unpack_var.cpp

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
   torch::jit::SubgraphRewriter var_rewriter;
   var_rewriter.RegisterRewritePattern(var_pattern, unpacked_pattern);
   var_rewriter.runOnGraph(graph);
-  LOG_DEBUG("Post unpack var: " << *graph);
+  LOG_GRAPH("Post unpack var: " << *graph);
 }

 } // namespace passes

core/partitioning/partitioning.cpp

Lines changed: 32 additions & 4 deletions
@@ -2,6 +2,7 @@

 #include <queue>
 #include "core/conversion/conversion.h"
+#include "core/conversion/evaluators/evaluators.h"
 #include "core/partitioning/shape_analysis.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
@@ -114,7 +115,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
       pytorch_nodes.push_back(n);
       prev_non_tensor_outputs = containNonTensorOutputs(n);
     } else {
-      // If pytorch_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
+      // If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
       // Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
       if (!pytorch_nodes.empty()) {
         new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
@@ -131,6 +132,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
       new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
     }
   }
+
   return std::move(new_seg_blocks);
 }

@@ -158,6 +160,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
       }
     }

+    // For each non-tensor value in the usage_counts map, keep updating the produce_id to the earliest segmented block
+    // that has/produces it.
     for (auto& use : usage_counts) {
       // Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value
       if (segmented_blocks[i].contain_raw_value(use.first)) {
@@ -177,9 +181,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
       // Segmented Blocks with non-tensor inputs will have to be re-segmented as
       // TRTorch doesn't support non-tensor inputs for a module.
       auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
-      segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
-      segmented_blocks.insert(
-          segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
+      auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
+      segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
       updated_segments.insert(first_torch_id);
     }
   }
@@ -258,6 +261,20 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
   return;
 }

+bool checkLoopEvaluatable(torch::jit::Node* n) {
+  bool compile_to_trt = true;
+  for (auto bn : n->blocks()[0]->nodes()) {
+    if (bn->kind() == torch::jit::prim::Loop) {
+      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
+    } else if (bn->kind() == torch::jit::prim::If) {
+      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
+    } else {
+      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
+    }
+  }
+  return compile_to_trt;
+}
+
 std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
   auto min_block_size = partition_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_operators(
@@ -298,6 +315,17 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
       }
       segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
       continue;
+    } else if (n->kind() == torch::jit::prim::Loop) {
+      if (!pytorch_nodes.empty()) {
+        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+        pytorch_nodes.clear();
+      }
+      if (checkLoopEvaluatable(n)) {
+        tensorrt_nodes.push_back(n);
+      } else {
+        segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
+      }
+      continue;
     }
     pytorch_nodes.push_back(n);
   }
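The resolveNonTensorInputs change above swaps index-based erase/insert on a vector for iterator-based erase/insert on a list (segmented_blocks_list with an idx_to_iter lookup, per the diff). The list idiom relies on std::list::erase returning an iterator to the element after the erased one, so the re-segmented blocks can be spliced in at exactly that position without shifting or invalidating the rest of the sequence. A minimal, self-contained sketch of that idiom with plain int elements, not TRTorch types:

#include <algorithm>
#include <iostream>
#include <list>
#include <vector>

int main() {
  std::list<int> blocks = {1, 2, 3, 4};
  std::vector<int> replacement = {20, 21};

  // Locate the block to re-segment (here: the value 2).
  auto it = std::find(blocks.begin(), blocks.end(), 2);

  // erase() returns the iterator following the erased element...
  auto next_iter = blocks.erase(it);
  // ...and insert() places the replacement elements before that position.
  blocks.insert(next_iter, replacement.begin(), replacement.end());

  for (int b : blocks) {
    std::cout << b << ' ';  // prints: 1 20 21 3 4
  }
  std::cout << '\n';
}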

core/partitioning/shape_analysis.cpp

Lines changed: 3 additions & 1 deletion
@@ -56,7 +56,9 @@ void getSegmentsOutputByRunning(
   for (auto& input : seg_block.raw_inputs()) {
     TRTORCH_CHECK(
         ivalues_maps.count(input),
-        "Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n");
+        "Could not find torch::jit::Value* " << input->debugName() << " produced from "
+            << util::node_info(input->node())
+            << " in lowering graph for mini graph input.\n");
     if (input->node()->kind() == torch::jit::prim::Param) {
       jit_inputs_ivalues.push_back(ivalues_maps[input]);
     } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {

core/runtime/register_trt_op.cpp

Lines changed: 3 additions & 0 deletions
@@ -112,6 +112,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   }

   c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());
+
+  // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it.
+  std::unique_lock<std::mutex> lock(compiled_engine->mu);
   compiled_engine->exec_ctx->enqueueV2(gpu_handles.data(), stream, nullptr);

   return outputs;
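Together with the std::mutex mu member added to TRTEngine in core/runtime/runtime.h below, this serializes concurrent calls into the same execution context: whichever thread holds the lock drives enqueueV2, and the lock is released when it goes out of scope. A minimal sketch of the same pattern, with a hypothetical Engine type standing in for the TensorRT objects:

#include <mutex>
#include <thread>
#include <vector>

// Hypothetical engine whose underlying enqueue call is not thread safe.
struct Engine {
  std::mutex mu;

  void enqueue_unsafe(int job) { (void)job; /* must not be called concurrently */ }

  void enqueue(int job) {
    // unique_lock releases the mutex when it goes out of scope,
    // so only one thread drives the engine at a time.
    std::unique_lock<std::mutex> lock(mu);
    enqueue_unsafe(job);
  }
};

int main() {
  Engine engine;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&engine, i] { engine.enqueue(i); });
  }
  for (auto& t : workers) {
    t.join();
  }
}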

core/runtime/runtime.h

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 #pragma once
 #include <map>
 #include <memory>
+#include <mutex>
 #include <utility>
 #include "ATen/core/function_schema.h"
 #include "NvInfer.h"
@@ -47,6 +48,7 @@ struct TRTEngine : torch::CustomClassHolder {
   std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
   std::pair<uint64_t, uint64_t> num_io;
   std::string name;
+  std::mutex mu;
   CudaDevice device_info;

   std::unordered_map<uint64_t, uint64_t> in_binding_map;

docs/_notebooks/Resnet50-example.html

Lines changed: 1 addition & 1 deletion
@@ -725,7 +725,7 @@
   </div>
 </div>
 <p>
-  <img alt="d34deeb0bd04450db415e7ad8573b82a" src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png"/>
+  <img alt="1794a581632146b3a0c2a5cea9db9870" src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png"/>
 </p>
 <section id="TRTorch-Getting-Started---ResNet-50">
 <h1 id="notebooks-resnet50-example--page-root">

docs/_notebooks/lenet-getting-started.html

Lines changed: 1 addition & 1 deletion
@@ -819,7 +819,7 @@
   </div>
 </div>
 <p>
-  <img alt="4ad32834008942b1a13a55d1a56e70b2" src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png"/>
+  <img alt="f6316bc6a1b54cada66e418a5317073b" src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png"/>
 </p>
 <section id="TRTorch-Getting-Started---LeNet">
 <h1 id="notebooks-lenet-getting-started--page-root">
