Commit ded366d

Merge branch 'torch_tensorrt_rebrand' into docs_refactor

2 parents: 8d9722d + 05b8376

27 files changed: +276 / -489 lines

core/compiler.cpp

Lines changed: 41 additions & 39 deletions
@@ -297,46 +297,48 @@ void MapInputsAndDetermineDTypes(
   cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
 
   for (auto& in : g->inputs()) {
-    auto est_type_opt = first_use_type_map.find(in)->second;
-    ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
-    if (est_type_opt && !spec.dtype_is_user_defined) {
-      // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
-      // type
-      LOG_INFO(
-          "Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
-          << in->debugName() << " has type " << est_type_opt.value()
-          << ". If this is incorrect explicitly set dtype for input and file a bug");
-      spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
-    } else if (!est_type_opt && !spec.dtype_is_user_defined) {
-      // If we cannot calculate the type and the user did not define the type, then default to FP32
-      LOG_WARNING(
-          "Cannot infer input type from calcuations in graph for input "
-          << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-      spec.dtype = nvinfer1::DataType::kFLOAT;
-    } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
-      if (!est_type_opt) {
-        LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
-      } else {
-        if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
-          std::stringstream ss;
-          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-          ss << cfg.convert_info.inputs.find(in)->second.dtype;
-          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-          ss << est_type_opt.value() << std::endl;
-          ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
-          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-          ss << "compatibility with PyTorch's data type convention is required.\n";
-          ss << "If you do indeed see errors at runtime either:\n";
-          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-          ss << "- Disable partial compilation by setting require_full_compilation to True";
-          auto warn_str = ss.str();
-          LOG_WARNING(warn_str);
-          // Overwrite type map with user settings
-          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+    if (static_params.find(in) == static_params.end()) {
+      ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
+      auto est_type_opt = first_use_type_map.find(in)->second;
+      if (est_type_opt && !spec.dtype_is_user_defined) {
+        // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
+        // type
+        LOG_INFO(
+            "Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
+            << in->debugName() << " has type " << est_type_opt.value()
+            << ". If this is incorrect explicitly set dtype for input and file a bug");
+        spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
+      } else if (!est_type_opt && !spec.dtype_is_user_defined) {
+        // If we cannot calculate the type and the user did not define the type, then default to FP32
+        LOG_WARNING(
+            "Cannot infer input type from calcuations in graph for input "
+            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
+        spec.dtype = nvinfer1::DataType::kFLOAT;
+      } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
+        if (!est_type_opt) {
+          LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
+        } else {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+            std::stringstream ss;
+            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+            ss << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+            ss << est_type_opt.value() << std::endl;
+            ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+            ss << "compatibility with PyTorch's data type convention is required.\n";
+            ss << "If you do indeed see errors at runtime either:\n";
+            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+            ss << "- Disable partial compilation by setting require_full_compilation to True";
+            auto warn_str = ss.str();
+            LOG_WARNING(warn_str);
+            // Overwrite type map with user settings
+            first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+          }
         }
+      } else {
+        // The user defined the type so no changes are necessary
       }
-    } else {
-      // The user defined the type so no changes are necessary
     }
   }
 }
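
Note on this hunk: the dtype-resolution body is now guarded by the static_params lookup, so inputs bound to frozen module weights are skipped and the policy runs only for true graph inputs. The policy itself is unchanged: a user-specified dtype wins (with a warning if it conflicts with the inferred type under partial compilation), otherwise the type inferred from the input's first tensor use is taken, otherwise FP32 is assumed. A rough standalone sketch of that policy, using hypothetical DType/InputSpec stand-ins rather than the project's IR types:

#include <iostream>
#include <optional>

enum class DType { kFloat32, kHalf, kInt8 }; // hypothetical stand-in for nvinfer1::DataType

struct InputSpec {
  std::optional<DType> user_dtype; // engaged iff dtype_is_user_defined
};

// Resolution order mirroring MapInputsAndDetermineDTypes: user setting first,
// then the graph-inferred type, then FP32 as the documented fallback.
DType resolve_dtype(const InputSpec& spec, std::optional<DType> inferred) {
  if (spec.user_dtype) {
    if (inferred && *inferred != *spec.user_dtype) {
      std::cerr << "warning: user dtype conflicts with inferred dtype; using the user setting\n";
    }
    return *spec.user_dtype;
  }
  if (inferred) {
    return *inferred; // calculated from the input's first tensor use
  }
  return DType::kFloat32; // nothing known: assume FP32 (the real code also warns)
}

int main() {
  std::cout << (resolve_dtype({}, std::nullopt) == DType::kFloat32) << "\n"; // prints 1
}
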
@@ -375,7 +377,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
 
-  return std::move(engine);
+  return engine;
 }
 
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
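
This return-statement change, like the matching one-liners in core/ir/StaticParams.cpp, core/ir/ir.cpp, and core/partitioning/partitioning.cpp below, drops std::move from returns of local variables. Returning a local by name keeps copy elision (NRVO) available, whereas return std::move(x); forces a move construction and suppresses elision, which is why GCC and Clang flag the pattern with -Wpessimizing-move. A minimal sketch of the difference, not taken from the repo:

#include <string>

std::string make_engine_blob() {
  std::string engine = "serialized-engine-bytes";
  return engine; // NRVO-eligible: may be constructed directly in the caller's storage
  // return std::move(engine); // pessimizing: forces a move and disables elision
}
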

core/compiler.h

Lines changed: 1 addition & 1 deletion
@@ -31,4 +31,4 @@ torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, run
 void set_device(const int gpu_id);
 
 } // namespace core
-} // namespace trtorch
+} // namespace torch_tensorrt

core/ir/StaticParams.cpp

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ StaticParams get_static_params(c10::ArrayRef<torch::jit::Value*> inputs, std::ve
   TRTORCH_CHECK(
       static_params.size() == params.size(),
       "Graph parameter parsing failed, mismatched number of static parameters and IValues")
-  return std::move(static_params);
+  return static_params;
 }
 
 } // namespace ir

core/ir/ir.cpp

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ InputSpecMap pair_input_vals_with_specs(std::vector<const torch::jit::Value*> va
     LOG_DEBUG("Paring " << i << ": " << vals[i]->debugName() << " : " << specs[i]);
     a.insert({vals[i], specs[i]});
   }
-  return std::move(a);
+  return a;
 }
 
 std::vector<const torch::jit::Value*> get_tensor_inputs(

core/partitioning/partitioning.cpp

Lines changed: 2 additions & 2 deletions
@@ -133,7 +133,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
     }
   }
 
-  return std::move(new_seg_blocks);
+  return new_seg_blocks;
 }
 
 void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shared_ptr<torch::jit::Graph> g
@@ -385,7 +385,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa
     finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
 
-  return std::move(segmented_blocks);
+  return segmented_blocks;
 }
 
 PartitionedGraph Partition(
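
The two hunks above are the last of the pessimizing-move cleanups. One related subtlety, in case the std::move habit came from move-only types: a local named in a return statement is treated as an rvalue before the copy constructor is considered, so even move-only results return cleanly without std::move. A small sketch with hypothetical names:

#include <memory>
#include <vector>

std::unique_ptr<std::vector<int>> make_blocks() {
  auto blocks = std::make_unique<std::vector<int>>();
  blocks->push_back(42);
  return blocks; // implicit move of the named local; std::move here would be redundant
}
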

core/util/jit_util.cpp

Lines changed: 0 additions & 113 deletions
This file was deleted.

docker/Dockerfile

Lines changed: 18 additions & 16 deletions
@@ -1,12 +1,12 @@
-ARG BASE=21.06
+ARG BASE=21.09
 ARG BASE_IMG=nvcr.io/nvidia/pytorch:${BASE}-py3
 FROM ${BASE_IMG} as base
 
-FROM base as trtorch-builder-base
+FROM base as torch-tensorrt-builder-base
 
-# Removing any bazel or trtorch pre-installed from the base image
+# Removing any bazel or torch-tensorrt pre-installed from the base image
 
-RUN rm -rf /opt/pytorch/trtorch /usr/bin/bazel
+RUN rm -rf /opt/torch-tensorrt /usr/bin/bazel
 
 RUN apt-get update && apt-get install --no-install-recommends -y curl gnupg
 RUN curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor > /etc/apt/trusted.gpg.d/bazel.gpg
@@ -20,30 +20,32 @@ RUN cp /usr/lib/x86_64-linux-gnu/libnvinfer.so /usr/lib/x86_64-linux-gnu/libnvin
 
 RUN apt-get update && apt-get install -y --no-install-recommends locales ninja-build && rm -rf /var/lib/apt/lists/* && locale-gen en_US.UTF-8
 
-FROM trtorch-builder-base as trtorch-builder
+FROM torch-tensorrt-builder-base as torch-tensorrt-builder
 
-COPY . /workspace/trtorch/src
-WORKDIR /workspace/trtorch/src
-RUN cp ./docker/WORKSPACE.cu.docker WORKSPACE
+COPY . /workspace/torch_tensorrt/src
+WORKDIR /workspace/torch_tensorrt/src
+RUN cp ./docker/WORKSPACE.docker WORKSPACE
 
 # This script builds both libtrtorch bin/lib/include tarball and the Pythin wheel, in dist/
 RUN ./docker/dist-build.sh
 
-FROM base as trtorch
+FROM base as torch-tensorrt
 
 # copy source repo
-COPY . /workspace/trtorch
-COPY --from=trtorch-builder /workspace/trtorch/src/dist/ .
-RUN patch -u /opt/conda/lib/python3.8/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py -i /workspace/trtorch/docker/qat.patch
+COPY . /workspace/torch_tensorrt
+COPY --from=torch-tensorrt-builder /workspace/torch_tensorrt/src/dist/ .
+RUN patch -u /opt/conda/lib/python3.8/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py -i /workspace/torch_tensorrt/docker/qat.patch
 RUN conda init bash
 
 RUN pip3 install ipywidgets --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org
 RUN jupyter nbextension enable --py widgetsnbextension
 
-RUN mkdir -p /opt/trtorch && tar xvf libtrtorch.tar.gz --strip-components 2 -C /opt/trtorch --exclude=LICENSE && pip3 install *.whl && rm -fr /workspace/trtorch/dist/*
+RUN pip3 install *.whl && rm -fr /workspace/torch_tensorrt/dist/* *.whl
 
-ENV LD_LIBRARY_PATH /opt/conda/lib/python3.8/site-packages/torch/lib:/opt/trtorch/lib:${LD_LIBRARY_PATH}
-ENV PATH /opt/trtorch/bin:${PATH}
+ENV LD_LIBRARY_PATH /opt/conda/lib/python3.8/site-packages/torch/lib:/opt/conda/lib/python3.8/site-packages/torch_tensorrt/lib:${LD_LIBRARY_PATH}
+ENV PATH /opt/conda/lib/python3.8/site-packages/torch_tensorrt/bin:${PATH}
+
+WORKDIR /workspace
+RUN mv /workspace/torch_tensorrt /opt/pytorch/torch_tensorrt
 
-WORKDIR /workspace/trtorch/
 CMD /bin/bash
