
Commit cd6b1b9

docs: Clean up testing and documentation

Signed-off-by: Naren Dasan <[email protected]>

1 parent: a9f33e4

16 files changed: 52 additions, 24 deletions

BUILD (1 addition, 1 deletion)

@@ -11,7 +11,7 @@ pkg_tar(
         "//core/conversion/evaluators:include",
         "//core/execution:include",
         "//core/lowering:include",
-        "//core/lowering/irfusers:include",
+        "//core/lowering/passes:include",
         "//core/util:include",
         "//core/util/logging:include"
     ],

README.md (4 additions, 0 deletions)

@@ -149,6 +149,10 @@ Thanks for wanting to contribute! There are two main ways to handle supporting a

 You can register a converter for your op using the `NodeConverterRegistry` inside your application.

+## Known Limitations
+
+- You cannot use both Adaptive Pooling in PyTorch and also use TRTorch Dynamic input shape
+
 ## Structure of the repo

 | Component | Description |
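
As context for the `NodeConverterRegistry` mentioned in this hunk, registering a converter from application code might look roughly like the sketch below. This is a hypothetical illustration: the registrar helper, context type, and argument types are assumptions modeled on the converter pattern used elsewhere in the repo, not an exact API reference.

#include "core/conversion/converters/converters.h"

namespace {
// Hypothetical registration of a converter for aten::relu. The schema string
// keys the converter; the lambda receives the conversion context, the JIT
// node, and its evaluated arguments, and returns true on success.
auto relu_registration = trtorch::core::conversion::converters::RegisterNodeConverters()
    .pattern({
        "aten::relu(Tensor self) -> (Tensor)",
        [](trtorch::core::conversion::ConversionCtx* ctx,
           const torch::jit::Node* n,
           trtorch::core::conversion::converters::args& args) -> bool {
            // Emit the equivalent TensorRT layer into the network under
            // construction, then bind the JIT output value to its result.
            auto in = args[0].ITensor();
            auto relu = ctx->net->addActivation(*in, nvinfer1::ActivationType::kRELU);
            ctx->AssociateValueAndTensor(n->outputs()[0], relu->getOutput(0));
            return true;
        }
    });
} // namespace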

core/compiler.cpp (1 addition, 1 deletion)

@@ -102,7 +102,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
                                         ExtraInfo cfg) {
     // TODO: Should be doing a functional transform but need PR #31978
     // [jit] More robust mangling
-    // torch::jit::script::Module new_mod = mod.clone();
+    //torch::jit::script::Module new_mod = mod.clone();
     torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
     std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
     for (const torch::jit::script::Method& method : mod.get_methods()) {

core/lowering/lowering.cpp (1 addition, 1 deletion)

@@ -26,7 +26,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
     passes::RemoveDropout(g);
     passes::FuseFlattenLinear(g);
     passes::UnpackAddMM(g);
-    passes::ExpandLogSoftmax(g);
+    passes::UnpackLogSoftmax(g);
     //passes::RemoveDimExeception(g);
     //irfusers::UnpackBatchNorm(g);
     torch::jit::EliminateDeadCode(g);
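
The rename from ExpandLogSoftmax to UnpackLogSoftmax matches the naming of the other Unpack* passes and what the pass does: decompose aten::log_softmax into ops that are easier to convert. A minimal sketch of such a pass using torch::jit::SubgraphRewriter follows; the pattern strings are illustrative, and the real implementation in core/lowering/passes/unpack_log_softmax.cpp may differ.

#include "torch/csrc/jit/passes/subgraph_rewrite.h"

void UnpackLogSoftmaxSketch(std::shared_ptr<torch::jit::Graph>& graph) {
    // Rewrite log_softmax(x, dim, dtype) into log(softmax(x, dim, dtype)),
    // two ops that are individually simpler to map to TensorRT layers.
    std::string logsoftmax_pattern = R"IR(
        graph(%input, %dim, %dtype):
            %out = aten::log_softmax(%input, %dim, %dtype)
            return (%out))IR";
    std::string softmax_log_pattern = R"IR(
        graph(%input, %dim, %dtype):
            %softmax = aten::softmax(%input, %dim, %dtype)
            %out = aten::log(%softmax)
            return (%out))IR";

    torch::jit::SubgraphRewriter unpack_log_softmax;
    unpack_log_softmax.RegisterRewritePattern(logsoftmax_pattern, softmax_log_pattern);
    unpack_log_softmax.runOnGraph(graph);
}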

core/lowering/passes/BUILD (5 additions, 5 deletions)

@@ -6,12 +6,12 @@ cc_library(
         "passes.h",
     ],
     srcs = [
+        "exception_elimination.cpp",
         "fuse_flatten_linear.cpp",
-        "expand_log_softmax.cpp",
         "remove_dropout.cpp",
+        "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
-        "exception_elimination.cpp",
-        "unpack_addmm.cpp"
+        "unpack_log_softmax.cpp",
     ],
     deps = [
         "//core/util:prelude",

@@ -23,7 +23,7 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")

 pkg_tar(
     name = "include",
-    package_dir = "core/lowering/irfusers/",
-    srcs = ["irfusers.h"],
+    package_dir = "core/lowering/passes/",
+    srcs = ["passes.h"],
 )

core/lowering/passes/exception_elimination.cpp (0 additions, 1 deletion)

@@ -21,7 +21,6 @@ struct ExceptionOrPassPatternElimination {
         : graph_(std::move(graph)) {}

     void run() {
-        LOG_GRAPH("Pre exeception or pass elimination: " << *graph_);
         findExceptionOrPassNodes(graph_->block());
         torch::jit::EliminateDeadCode(graph_);
         LOG_GRAPH("Post exeception or pass elimination: " << *graph_);

core/lowering/passes/fuse_flatten_linear.cpp (3 additions, 0 deletions)

@@ -1,6 +1,8 @@
 #include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"

+#include "core/util/prelude.h"
+
 namespace trtorch {
 namespace core {
 namespace lowering {

@@ -38,6 +40,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
     flatten_linear_bias_none_to_linear.RegisterRewritePattern(
         flatten_linear_bias_none_pattern, fused_linear_bias_none);
     flatten_linear_bias_none_to_linear.runOnGraph(graph);
+    LOG_GRAPH("Post flatten linear: " << *graph);
 }

 } // namespace passes
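
The names in this hunk (flatten_linear_bias_none_pattern, fused_linear_bias_none) indicate a rewrite that elides an aten::flatten feeding aten::linear, presumably sound in this pipeline because the downstream linear converter handles the dimension collapse itself. A speculative sketch of such a pattern pair; the actual strings appear earlier in this file and may differ, and the bias=None case shown in the hunk would use a closely related pair.

// Illustrative only: drop a flatten that feeds directly into linear.
std::string flatten_linear_pattern = R"IR(
    graph(%input, %start_dim, %end_dim, %weight, %bias):
        %flat = aten::flatten(%input, %start_dim, %end_dim)
        %res = aten::linear(%flat, %weight, %bias)
        return (%res))IR";
std::string fused_linear = R"IR(
    graph(%input, %start_dim, %end_dim, %weight, %bias):
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";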

core/lowering/passes/passes.h (2 additions, 2 deletions)

@@ -8,10 +8,10 @@ namespace lowering {
 namespace passes {

 void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
-void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
-void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);

 } // namespace irfusers

core/lowering/passes/remove_dropout.cpp (3 additions, 0 deletions)

@@ -1,6 +1,8 @@
 #include <torch/csrc/jit/passes/fuse_linear.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>

+#include "core/util/prelude.h"
+
 namespace trtorch {
 namespace core {
 namespace lowering {

@@ -20,6 +22,7 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
     remove_dropout.RegisterRewritePattern(
         dropout_pattern, no_dropout_pattern);
     remove_dropout.runOnGraph(graph);
+    LOG_GRAPH("Post remove dropout: " << *graph);
 }

 } // namespace passes
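
For reference, the dropout_pattern / no_dropout_pattern pair used above plausibly rewrites aten::dropout to the identity, which is valid because dropout passes its input through unchanged at inference time. A sketch under that assumption; the exact strings in the file may differ.

// Illustrative rewrite: dropout(input, p, train) -> input.
// Safe for inference, where dropout is a no-op.
std::string dropout_pattern = R"IR(
    graph(%input, %p, %train):
        %out = aten::dropout(%input, %p, %train)
        return (%out))IR";
std::string no_dropout_pattern = R"IR(
    graph(%input, %p, %train):
        return (%input))IR";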

core/lowering/passes/unpack_addmm.cpp (3 additions, 0 deletions)

@@ -1,6 +1,8 @@
 #include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"

+#include "core/util/prelude.h"
+
 namespace trtorch {
 namespace core {
 namespace lowering {

@@ -23,6 +25,7 @@ void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
     torch::jit::SubgraphRewriter unpack_addmm;
     unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
     unpack_addmm.runOnGraph(graph);
+    LOG_GRAPH("Post unpack addmm: " << *graph);
 }
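The addmm_pattern / mm_add_pattern pair referenced here decomposes aten::addmm(bias, mat1, mat2, beta, alpha), i.e. beta * bias + alpha * (mat1 @ mat2), into a matrix multiply followed by an add. A sketch of that decomposition for the common beta = alpha = 1 case; illustrative, not the file's exact strings.

// Illustrative rewrite: addmm(b, x, w, 1, 1) -> matmul(x, w) + b.
std::string addmm_pattern = R"IR(
    graph(%b, %x, %w, %1):
        %out = aten::addmm(%b, %x, %w, %1, %1)
        return (%out))IR";
std::string mm_add_pattern = R"IR(
    graph(%b, %x, %w, %1):
        %mm = aten::matmul(%x, %w)
        %out = aten::add(%mm, %b, %1)
        return (%out))IR";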