Skip to content

Commit 768edcb

Browse files
committed
tests: Linting and clean up
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 2f11791 commit 768edcb

File tree

5 files changed

+22
-17
lines changed

5 files changed

+22
-17
lines changed

core/conversion/conversion.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,7 @@ bool OpSupported(const torch::jit::Node* n) {
1818
return evaluators::shouldEvalAtConversionTime(n) || converters::node_is_convertable(n);
1919
}
2020

21-
c10::optional<torch::jit::IValue> EvaluateNode(
22-
ConversionCtx* ctx,
23-
const torch::jit::Node* n,
24-
int level,
25-
int limit) {
21+
c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level, int limit) {
2622
// Check to see if you can just go through and eval all of these AOT (saves
2723
// the recursion) Also probably a better way to deal with the two error cases;
2824
TRTORCH_CHECK(

core/conversion/conversion.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ bool OpSupported(const torch::jit::Node* n);
4242

4343
bool VerifyConverterSupportForBlock(const torch::jit::Block* b);
4444

45-
c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level = 0, int limit = 10);
45+
c10::optional<torch::jit::IValue> EvaluateNode(
46+
ConversionCtx* ctx,
47+
const torch::jit::Node* n,
48+
int level = 0,
49+
int limit = 10);
4650

4751
} // namespace conversion
4852
} // namespace core

tests/core/lowering/test_remove_contiguous_pass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include <string>
2-
#include "gtest/gtest.h"
3-
#include "tests/util/util.h"
42
#include "core/compiler.h"
53
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
66
#include "torch/csrc/jit/ir/irparser.h"
77
#include "torch/csrc/jit/ir/subgraph_matcher.h"
88

tests/util/evaluate_graph.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,24 @@
77
#include "core/conversion/var/Var.h"
88
#include "core/util/prelude.h"
99

10-
1110
namespace trtorch {
1211
namespace tests {
1312
namespace util {
1413

1514
std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::vector<torch::jit::IValue> inputs) {
15+
LOG_DEBUG("Running TRTorch Version");
16+
1617
core::conversion::ConversionCtx* ctx = new core::conversion::ConversionCtx({});
17-
std::cout << "IJKWIOJWQIJOWQ" << std::endl;
1818

1919
TRTORCH_CHECK(inputs.size() == b->inputs().size(), "Amount of provided inputs do not match number of graph inputs");
2020
for (size_t i = 0; i < inputs.size(); i++) {
2121
ctx->AssociateValueAndIValue(b->inputs()[i], inputs[i]);
2222
}
23-
std::cout << "IJKWIOJWQIJOWQ" << std::endl;
23+
2424
for (const auto n : b->nodes()) {
25-
std::cout << *n << std::endl;
26-
TRTORCH_CHECK(core::conversion::evaluators::shouldEvalAtConversionTime(n), "Test graph contains non evaluatable nodes: " << *n);
25+
TRTORCH_CHECK(
26+
core::conversion::evaluators::shouldEvalAtConversionTime(n),
27+
"Test graph contains non evaluatable nodes: " << *n);
2728
auto eval = core::conversion::EvaluateNode(ctx, n);
2829
if (eval) {
2930
if (!eval.value().isTensor()) {
@@ -36,7 +37,7 @@ std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::v
3637
}
3738

3839
std::vector<torch::jit::IValue> outputs;
39-
for(auto o : b->outputs()) {
40+
for (auto o : b->outputs()) {
4041
auto it = ctx->evaluated_value_map.find(o);
4142
TRTORCH_CHECK(
4243
it != ctx->evaluated_value_map.end(),
@@ -48,12 +49,14 @@ std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::v
4849
return outputs;
4950
}
5051

51-
std::vector<torch::jit::IValue> EvaluateGraphJIT(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue> inputs) {
52+
std::vector<torch::jit::IValue> EvaluateGraphJIT(
53+
std::shared_ptr<torch::jit::Graph>& g,
54+
std::vector<torch::jit::IValue> inputs) {
5255
LOG_DEBUG("Running JIT version");
5356

5457
torch::jit::GraphExecutor executor(g, "");
5558
auto stack = torch::jit::Stack();
56-
for (auto& i : inputs) {
59+
for (auto& i : inputs) {
5760
torch::jit::push(stack, i);
5861
}
5962

tests/util/util.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ std::vector<at::Tensor> RunModuleForwardAsEngine(torch::jit::Module& mod, std::v
4747
std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::vector<torch::jit::IValue> inputs);
4848

4949
// Runs evaluatable graphs through the JIT interpreter and returns results
50-
std::vector<torch::jit::IValue> EvaluateGraphJIT(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue> inputs);
50+
std::vector<torch::jit::IValue> EvaluateGraphJIT(
51+
std::shared_ptr<torch::jit::Graph>& g,
52+
std::vector<torch::jit::IValue> inputs);
5153
} // namespace util
5254
} // namespace tests
5355
} // namespace trtorch

0 commit comments

Comments
 (0)