Skip to content

Commit 20022d4

Browse files
authored
Merge pull request #288 from NVIDIA/to_backend_device
Adding the new device API, fixing the a nested dict issue in the existing compile phase, adding new lowering pass for bn
2 parents b787c5e + 86bb5b7 commit 20022d4

File tree

13 files changed

+167
-56
lines changed

13 files changed

+167
-56
lines changed

.bazelversion

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.7.0
1+
4.0.0

WORKSPACE

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ http_archive(
7979

8080
http_archive(
8181
name = "tensorrt",
82-
urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.1/tars/TensorRT-7.2.1.6.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz",],
82+
urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.2/tars/TensorRT-7.2.2.3.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz",],
8383
build_file = "@//third_party/tensorrt/archive:BUILD",
84-
sha256 = "8def6b03b0c8c3751f560df21b3e99668ae05aab5140b1d38b8e51e4a0ffbbb8",
85-
strip_prefix = "TensorRT-7.2.1.6"
84+
strip_prefix = "TensorRT-7.2.2.3",
85+
sha256 = "b5c325e38e1d92ce1ce92ca8b54ede9c224bf128c9a53eb0b9022f1ee4313ee0"
8686
)
8787

8888
####################################################################################

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
4040
passes::Conv2DToConvolution(g);
4141
passes::Conv3DToConvolution(g);
4242
passes::FuseAddMMBranches(g);
43+
passes::RemoveBNDimCheck(g);
4344
torch::jit::EliminateCommonSubexpression(g);
4445
// torch::jit::UnrollLoops(g);
4546
torch::jit::EliminateCommonSubexpression(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cc_library(
1818
"exception_elimination.cpp",
1919
"fuse_addmm_branches.cpp",
2020
"fuse_flatten_linear.cpp",
21+
"remove_bn_dim_check.cpp",
2122
"remove_contiguous.cpp",
2223
"remove_dropout.cpp",
2324
"remove_to.cpp",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1212
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
1313
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
1414
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
15+
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
1516
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
1617
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
1718
void RemoveTo(std::shared_ptr<torch::jit::Graph> graph);
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#include "torch/csrc/jit/ir/alias_analysis.h"
2+
#include "torch/csrc/jit/jit_log.h"
3+
#include "torch/csrc/jit/passes/constant_propagation.h"
4+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
5+
#include "torch/csrc/jit/passes/guard_elimination.h"
6+
#include "torch/csrc/jit/passes/peephole.h"
7+
#include "torch/csrc/jit/runtime/graph_executor.h"
8+
9+
#include "core/util/prelude.h"
10+
11+
#include <vector>
12+
13+
namespace trtorch {
14+
namespace core {
15+
namespace lowering {
16+
namespace passes {
17+
namespace {
18+
using namespace torch::jit;
19+
struct BNDimCheckRemoval {
20+
BNDimCheckRemoval(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}
21+
22+
void run() {
23+
findBNDimCheckNodes(graph_->block());
24+
torch::jit::EliminateDeadCode(graph_);
25+
LOG_GRAPH("Post batch norm dim check removal: " << *graph_);
26+
}
27+
28+
private:
29+
bool isBNDimCheckNodes(Node* n) {
30+
/// Check if this Node hosts a pattern like so:
31+
/// %290 : bool = aten::ne(%289, %9)
32+
/// = prim::If(%290)
33+
/// block0():
34+
/// %291 : str = aten::format(%10, %289)
35+
/// = prim::RaiseException(%291)
36+
/// -> ()
37+
/// block1():
38+
/// -> ()
39+
40+
if (n->blocks().size() != 2) {
41+
return false;
42+
}
43+
auto arm1 = n->blocks()[0];
44+
auto arm2 = n->blocks()[1];
45+
if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
46+
// Make sure that the node doesn't actually produce any Value that are
47+
// used by other nodes
48+
return false;
49+
}
50+
51+
auto arm1_start = arm1->nodes().begin();
52+
53+
if ((*arm1_start)->kind() != c10::Symbol::fromQualString("aten::format") &&
54+
(*(++arm1_start))->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) {
55+
// Make sure that block0 is solely just the exception and the return
56+
return false;
57+
}
58+
59+
if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
60+
// Make sure that block1 is solely the return
61+
return false;
62+
}
63+
64+
return true;
65+
}
66+
67+
void findBNDimCheckNodes(Block* b) {
68+
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
69+
auto n = *it;
70+
if (n->kind() == prim::If && isBNDimCheckNodes(n)) {
71+
LOG_GRAPH("Found that node " << *n << " is an batch norm dim check node (EliminateChecks)" << std::endl);
72+
it.destroyCurrent();
73+
}
74+
}
75+
}
76+
77+
std::shared_ptr<Graph> graph_;
78+
};
79+
} // namespace
80+
81+
void RemoveBNDimCheck(std::shared_ptr<Graph> graph) {
82+
BNDimCheckRemoval bndcr(std::move(graph));
83+
bndcr.run();
84+
}
85+
86+
} // namespace passes
87+
} // namespace lowering
88+
} // namespace core
89+
} // namespace trtorch

docsrc/tutorials/use_from_pytorch.rst

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,31 +32,36 @@ at the documentation for the TRTorch ``TensorRTCompileSpec`` API.
3232
.. code-block:: python
3333
3434
spec = {
35-
"forward": trtorch.TensorRTCompileSpec({
36-
"input_shapes": [[1, 3, 300, 300]],
37-
"op_precision": torch.half,
38-
"refit": False,
39-
"debug": False,
40-
"strict_types": False,
41-
"allow_gpu_fallback": True,
42-
"device_type": "gpu",
43-
"capability": trtorch.EngineCapability.default,
44-
"num_min_timing_iters": 2,
45-
"num_avg_timing_iters": 1,
46-
"max_batch_size": 0,
47-
})
48-
}
35+
"forward":
36+
trtorch.TensorRTCompileSpec({
37+
"input_shapes": [[1, 3, 300, 300]],
38+
"op_precision": torch.half,
39+
"refit": False,
40+
"debug": False,
41+
"strict_types": False,
42+
"device": {
43+
"device_type": trtorch.DeviceType.GPU,
44+
"gpu_id": 0,
45+
"dla_core": 0,
46+
"allow_gpu_fallback": True
47+
},
48+
"capability": trtorch.EngineCapability.default,
49+
"num_min_timing_iters": 2,
50+
"num_avg_timing_iters": 1,
51+
"max_batch_size": 0,
52+
})
53+
}
4954
5055
Now to compile with TRTorch, provide the target module objects and the spec dictionary to ``torch._C._jit_to_tensorrt``
5156

5257
.. code-block:: python
5358
54-
trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)
59+
trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)
5560
5661
To run explicitly call the function of the method you want to run (vs. how you can just call on the module itself in standard PyTorch)
5762

5863
.. code-block:: python
5964
60-
input = torch.randn((1, 3, 300, 300).to("cuda").to(torch.half)
65+
input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
6166
print(trt_model.forward(input))
6267

py/trtorch/_compile_spec.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,6 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
147147
assert isinstance(compile_spec["strict_types"], bool)
148148
info.strict_types = compile_spec["strict_types"]
149149

150-
if "allow_gpu_fallback" in compile_spec:
151-
assert isinstance(compile_spec["allow_gpu_fallback"], bool)
152-
info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"]
153-
154150
if "device" in compile_spec:
155151
info.device = _parse_device(compile_spec["device"])
156152

@@ -177,7 +173,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
177173
return info
178174

179175

180-
def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
176+
def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
181177
"""
182178
Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
183179
@@ -199,10 +195,10 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
199195
} # Dynamic input shape for input #2
200196
],
201197
"device": {
202-
"device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
203-
"gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA)
204-
"dla_core": 0, # (DLA only) Target dla core id to run engine
205-
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
198+
"device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
199+
"gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA)
200+
"dla_core": 0, # (DLA only) Target dla core id to run engine
201+
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
206202
},
207203
"op_precision": torch.half, # Operating precision set to FP16
208204
"refit": False, # enable refit
@@ -235,14 +231,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
235231
ir.set_max(i.max)
236232
backend_spec.append_input_range(ir)
237233

238-
for i in parsed_spec.device:
239-
ir = torch.classes.tensorrt.Device()
240-
ir.set_device_type(i.device_type)
241-
ir.set_gpu_id(i.gpu_id)
242-
ir.set_dla_core(i.dla_core)
243-
ir.set_allow_gpu_fallback(i.allow_gpu_fallback)
244-
backend_spec.set_device(ir)
234+
d = torch.classes.tensorrt.Device()
235+
d.set_device_type(int(parsed_spec.device.device_type))
236+
d.set_gpu_id(parsed_spec.device.gpu_id)
237+
d.set_dla_core(parsed_spec.device.dla_core)
238+
d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
245239

240+
backend_spec.set_device(d)
246241
backend_spec.set_op_precision(int(parsed_spec.op_precision))
247242
backend_spec.set_refit(parsed_spec.refit)
248243
backend_spec.set_debug(parsed_spec.debug)

py/trtorch/csrc/register_tensorrt_classes.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,32 @@
33
namespace trtorch {
44
namespace backend {
55
namespace {
6-
void RegisterTRTCompileSpec() {
6+
77
#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
88
(registry).def("set_" #field_name, &class_name::set_##field_name); \
99
(registry).def("get_" #field_name, &class_name::get_##field_name);
1010

11+
void RegisterTRTCompileSpec() {
1112
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
1213
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());
1314

1415
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
1516
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
1617
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);
1718

19+
static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
20+
torch::class_<trtorch::pyapi::Device>("tensorrt", "Device").def(torch::init<>());
21+
22+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
23+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
24+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
25+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
26+
1827
static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
1928
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
2029
.def(torch::init<>())
2130
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
31+
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
2232
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
2333

2434
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);

py/trtorch/csrc/tensorrt_backend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10::
4646
auto method = mod.get_method(method_name);
4747
auto g = method.graph();
4848

49-
auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass<trtorch::pyapi::CompileSpec>();
49+
auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
5050
LOG_DEBUG(raw_spec->stringify());
5151
auto cfg = raw_spec->toInternalCompileSpec();
5252
auto convert_cfg = std::move(cfg.convert_info);

0 commit comments

Comments
 (0)