Commit cbe04cb

Merge branch 'master' into dyn_shapes

2 parents 6d0b0f6 + e3b9929

128 files changed: 786 additions, 146 deletions

core/conversion/conversionctx/BUILD

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ cc_library(
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 2 additions & 9 deletions
@@ -9,28 +9,21 @@
 #include "torch/csrc/jit/ir/ir.h"

 #include <cuda_runtime.h>
+#include "core/ir/ir.h"
 #include "core/util/prelude.h"

 namespace torch_tensorrt {
 namespace core {
 namespace conversion {

-struct Device {
-  nvinfer1::DeviceType device_type;
-  int64_t gpu_id;
-  int64_t dla_core;
-  bool allow_gpu_fallback;
-  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
-};
-
 struct BuilderSettings {
   std::set<nvinfer1::DataType> enabled_precisions = {};
   bool sparse_weights = false;
   bool disable_tf32 = false;
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
-  Device device;
+  ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
   uint64_t num_avg_timing_iters = 1;
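The net effect of this hunk is that the conversion stage no longer defines its own Device; BuilderSettings now holds the ir::Device that this commit adds to core/ir/ir.h (shown further down). A minimal sketch of a call site after the change, assuming the surrounding Torch-TensorRT headers; the function name is hypothetical:

#include "core/conversion/conversionctx/ConversionCtx.h"

// Hypothetical call site for illustration only. Field names come from the
// diffs in this commit; everything else is an assumption.
torch_tensorrt::core::conversion::BuilderSettings make_settings() {
  torch_tensorrt::core::conversion::BuilderSettings settings;
  settings.device = torch_tensorrt::core::ir::Device();  // was conversion::Device before this commit
  settings.device.gpu_id = 0;                            // defaults already target GPU 0
  settings.enabled_precisions = {nvinfer1::DataType::kFLOAT};
  return settings;
}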

core/conversion/converters/impl/element_wise.cpp

Lines changed: 13 additions & 22 deletions
@@ -166,11 +166,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
              // Should implement self - alpha * other
              auto self = args[0].ITensorOrFreeze(ctx);
-             auto scalar = args[2].unwrapToScalar().to<float>();
              auto other = args[1].ITensorOrFreeze(ctx);
+             auto scalar = args[2].unwrapToScalar();

-             if (1 != scalar) {
-               auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+             if (1 != scalar.to<float>()) {
+               auto alphaTensor = scalar_to_tensor(ctx, scalar);
                auto scaleLayer = add_elementwise(
                    ctx,
                    nvinfer1::ElementWiseOperation::kPROD,
@@ -214,11 +214,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
              // Should implement self - alpha * other
              auto self = args[0].ITensorOrFreeze(ctx);
-             auto scalar = args[2].unwrapToScalar().to<float>();
              auto other = args[1].ITensorOrFreeze(ctx);
+             auto scalar = args[2].unwrapToScalar();

-             if (1 != scalar) {
-               auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+             if (1 != scalar.to<float>()) {
+               auto alphaTensor = scalar_to_tensor(ctx, scalar);
                auto scaleLayer = add_elementwise(
                    ctx,
                    nvinfer1::ElementWiseOperation::kPROD,
@@ -351,8 +351,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
           {"aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)",
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
              auto self = args[0].ITensorOrFreeze(ctx);
-             auto otherScalar = args[1].unwrapToScalar().to<float>();
-             auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+             auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
              auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
              TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

@@ -381,8 +380,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
           {"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
              auto self = args[0].ITensorOrFreeze(ctx);
-             auto otherScalar = args[1].unwrapToScalar().to<float>();
-             auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+             auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
              auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
              TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

@@ -481,18 +479,12 @@ auto element_wise_registrations TORCHTRT_UNUSED =
           {"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
              auto self = args[0].ITensorOrFreeze(ctx);
-             auto scalar = args[1].unwrapToScalar();
-             nvinfer1::ITensor* scalar_tensor;
-             if (self->getType() == nvinfer1::DataType::kFLOAT || self->getType() == nvinfer1::DataType::kHALF) {
-               scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<float>()}));
-             } else {
-               scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<int>()}));
-             }
+             auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
              auto equal = add_elementwise(
                  ctx,
                  nvinfer1::ElementWiseOperation::kEQUAL,
                  self,
-                 scalar_tensor,
+                 other,
                  util::node_info(n) + std::string("is_equal"));
              TORCHTRT_CHECK(equal, "Unable to create elementwise equal layer from node: " << *n);
              // XOR with ones negates and produces not_equal result
@@ -534,8 +526,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
           {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
              auto self = args[0].ITensorOrFreeze(ctx);
-             auto exponentScalar = args[1].unwrapToScalar().to<float>();
-             auto exponent = tensor_to_const(ctx, torch::tensor({exponentScalar}));
+             auto exponent = scalar_to_tensor(ctx, args[1].unwrapToScalar());
              auto pow =
                  add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, self, exponent, util::node_info(n));
              TORCHTRT_CHECK(pow, "Unable to create Power layer from node: " << *n);
@@ -681,9 +672,9 @@ auto element_wise_registrations TORCHTRT_UNUSED =
           {"aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)",
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
              auto self = args[0].ITensorOrFreeze(ctx);
-             auto otherScalar = args[1].unwrapToScalar().to<float>();
-             auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+             auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
              if (self->getType() == nvinfer1::DataType::kBOOL) {
+               auto otherScalar = args[1].unwrapToScalar().to<float>();
                if (otherScalar == 0 || otherScalar == 1) {
                  LOG_DEBUG("Since input tensor is type bool, casting input tensor and scalar to int32");
                  other = castITensor(ctx, other, nvinfer1::DataType::kINT32);
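Every hunk above swaps the hand-rolled tensor_to_const(ctx, torch::tensor({scalar.to<float>()})) pattern for a shared scalar_to_tensor helper, so the scalar's own type decides the constant's dtype instead of always forcing float. The helper's implementation is not part of this excerpt; a hedged sketch of what such a helper could look like, mirroring the branch removed from the aten::ne.Scalar converter, is:

// Sketch only -- the real helper ships elsewhere in the converter utilities and may differ.
nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
  if (s.isIntegral(/*includeBool=*/false)) {
    // Integer scalars stay integral, so comparisons against int tensors keep their dtype.
    return tensor_to_const(ctx, torch::tensor({s.to<int>()}));
  }
  return tensor_to_const(ctx, torch::tensor({s.to<float>()}));
}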

core/ir/ir.h

Lines changed: 8 additions & 0 deletions
@@ -17,6 +17,14 @@ enum class ShapeMode {
   kMAX,
 };

+struct Device {
+  nvinfer1::DeviceType device_type;
+  int64_t gpu_id;
+  int64_t dla_core;
+  bool allow_gpu_fallback;
+  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
+};
+
 struct Input : torch::CustomClassHolder {
   Input(){};
   Input(

core/lowering/BUILD

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ cc_library(
     deps = [
         "//core/lowering/passes",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
        "//conditions:default": ["@libtorch//:libtorch"],

core/lowering/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
@@ -15,6 +15,8 @@ set(HEADER_FILES
 target_sources(${lib_name}
     PRIVATE
         ${CXX_SRCS}
+    PUBLIC
+        $<TARGET_OBJECTS:core_ir>
         $<TARGET_OBJECTS:core_util>
 )

@@ -25,8 +27,9 @@ target_include_directories(${lib_name}

 target_link_libraries(${lib_name}
     PUBLIC
+        TensorRT::nvinfer
         torch
-    PRIVATE
+        core_ir
         core_util
 )


core/lowering/lowering.cpp

Lines changed: 7 additions & 2 deletions
@@ -26,7 +26,7 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }

-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
   torch::jit::RemoveTensorMutation(g);
@@ -70,6 +70,11 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
+  passes::ReplaceScalarImplicit(g);
+  passes::RewriteInputsWithParams(g, params);
   LOG_GRAPH(*g);
 }

@@ -103,7 +108,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
   // In quantization aware trained (QAT) models, weights are passed through quantize and
   // dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
   LOG_GRAPH("Torch-TensorRT.TorchScript Graph Lowering");
-  lowering::LowerGraph(graph_and_ivalues.first, lower_info);
+  lowering::LowerGraph(graph_and_ivalues.first, graph_and_ivalues.second, lower_info);

   // Is this necessary?
   // lowering::LowerBlock(g->block());
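LowerGraph now takes the frozen parameter list by non-const reference, presumably so the new RewriteInputsWithParams pass can update graph inputs and the parameter list together, while the UnpackAndCast* passes receive the target device as a string. A hedged sketch of how a caller outside Lower() adapts to the new signature, assuming default LowerInfo settings and the LowerInfo fields shown in the lowering.h hunk below; the function name is hypothetical:

#include "core/lowering/lowering.h"

// Assumed stand-alone usage that mirrors the updated call inside Lower().
void lower_for_trt(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params) {
  torch_tensorrt::core::lowering::LowerInfo lower_info;  // defaults assumed sufficient here
  torch_tensorrt::core::lowering::LowerGraph(g, params, lower_info);
}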

core/lowering/lowering.h

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 #pragma once
 #include <memory>
+#include "core/ir/ir.h"
 #include "torch/csrc/jit/ir/ir.h"

 namespace torch_tensorrt {
@@ -15,8 +16,13 @@ struct LowerInfo {
   // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
   // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
   bool disable_cse = false;
+  ir::Device target_device;
   std::vector<std::string> forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
+
+  std::string getGPUDeviceString() {
+    return "cuda:" + std::to_string(target_device.gpu_id);
+  };
 };

 void LowerBlock(torch::jit::Block* b);
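getGPUDeviceString() turns the lowering target into the torch-style device string that lowering.cpp hands to the new UnpackAndCastMaskedFill, UnpackAndCastNumToTensor, and UnpackAndCastFull passes. A small usage sketch, relying only on the fields shown above:

#include <string>
#include "core/lowering/lowering.h"

// Illustration only, using just the members added in this hunk.
std::string example_device_string() {
  torch_tensorrt::core::lowering::LowerInfo info;
  info.target_device.gpu_id = 1;
  return info.getGPUDeviceString();  // "cuda:1", as consumed by the UnpackAndCast* passes
}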

core/lowering/passes/BUILD

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@ cc_library(
     name = "passes",
     srcs = [
         "convNd_to_convolution.cpp",
+        "device_casting.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "linear_to_addmm.cpp",
@@ -27,6 +28,7 @@ cc_library(
         "remove_dropout.cpp",
         "remove_nops.cpp",
         "remove_unnecessary_casts.cpp",
+        "rewrite_inputs_with_params.cpp",
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",

core/lowering/passes/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 target_sources(${lib_name}
     PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/convNd_to_convolution.cpp"
+            "${CMAKE_CURRENT_SOURCE_DIR}/device_casting.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/exception_elimination.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/fuse_addmm_branches.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
@@ -24,6 +25,7 @@ target_sources(${lib_name}
             "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"
+            "${CMAKE_CURRENT_SOURCE_DIR}/rewrite_inputs_with_params.cpp"
 )

 set(HEADER_FILES
