Skip to content

Commit 3dd9daa

Browse files
authored
Merge pull request #268 from NVIDIA/testing
Adds testing infrastructure for evaluators and lowering passes
2 parents 553ef02 + 768edcb commit 3dd9daa

34 files changed

+234
-20
lines changed

core/conversion/conversion.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,7 @@ bool OpSupported(const torch::jit::Node* n) {
1818
return evaluators::shouldEvalAtConversionTime(n) || converters::node_is_convertable(n);
1919
}
2020

21-
c10::optional<torch::jit::IValue> EvaluateNode(
22-
ConversionCtx* ctx,
23-
const torch::jit::Node* n,
24-
int level = 0,
25-
int limit = 10) {
21+
c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level, int limit) {
2622
// Check to see if you can just go through and eval all of these AOT (saves
2723
// the recursion) Also probably a better way to deal with the two error cases;
2824
TRTORCH_CHECK(

core/conversion/conversion.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ bool OpSupported(const torch::jit::Node* n);
4242

4343
bool VerifyConverterSupportForBlock(const torch::jit::Block* b);
4444

45+
c10::optional<torch::jit::IValue> EvaluateNode(
46+
ConversionCtx* ctx,
47+
const torch::jit::Node* n,
48+
int level = 0,
49+
int limit = 10);
50+
4551
} // namespace conversion
4652
} // namespace core
4753
} // namespace trtorch

tests/BUILD

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,32 @@
11
test_suite(
22
name = "tests",
33
tests = [
4-
"//tests/core/converters:test_converters",
5-
"//tests/modules:test_modules"
4+
"//tests/core:core_tests",
5+
"//tests/modules:module_tests"
66
],
77
)
88

99
test_suite(
1010
name = "required_and_optional_tests",
1111
tests = [
1212
":tests",
13-
"//tests/accuracy:test_accuracy"
13+
"//tests/accuracy:accuracy_tests"
1414
]
1515
)
1616

1717
test_suite(
1818
name = "aarch64_tests",
1919
tests = [
20-
"//tests/core/converters:test_converters",
21-
"//tests/modules:test_modules_aarch64"
20+
"//tests/core:core_tests",
21+
"//tests/modules:aarch64_module_tests"
2222
],
2323
)
2424

2525
test_suite(
2626
name = "required_and_optional_aarch64_tests",
2727
tests = [
28-
":aarch64_tests",
29-
"//tests/accuracy:test_accuracy_aarch64"
28+
":aarch64_tests",
29+
"//tests/accuracy:aarch64_accuracy_tests"
3030
]
3131
)
3232

tests/accuracy/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ filegroup(
44
)
55

66
test_suite(
7-
name = "test_accuracy_aarch64",
7+
name = "aarch64_accuracy_tests",
88
tests = [
99
":test_dla_int8_accuracy",
1010
":test_dla_fp16_accuracy",
@@ -15,7 +15,7 @@ test_suite(
1515
)
1616

1717
test_suite(
18-
name = "test_accuracy",
18+
name = "accuracy_tests",
1919
tests = [
2020
":test_int8_accuracy",
2121
":test_fp16_accuracy",

tests/core/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
test_suite(
2+
name = "core_tests",
3+
tests = [
4+
"//tests/core/conversion:conversion_tests",
5+
"//tests/core/lowering:lowering_tests",
6+
],
7+
)

tests/core/conversion/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
test_suite(
2+
name = "conversion_tests",
3+
tests = [
4+
"//tests/core/conversion/converters:converter_tests",
5+
"//tests/core/conversion/evaluators:evaluator_tests",
6+
],
7+
)

tests/core/converters/BUILD renamed to tests/core/conversion/converters/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
load("//tests/core/converters:converter_test.bzl", "converter_test")
1+
load("//tests/core/conversion/converters:converter_test.bzl", "converter_test")
22

33
config_setting(
44
name = "use_pre_cxx11_abi",
@@ -72,7 +72,7 @@ converter_test(
7272
)
7373

7474
test_suite(
75-
name = "test_converters",
75+
name = "converter_tests",
7676
tests = [
7777
":test_activation",
7878
":test_batch_norm",

0 commit comments

Comments
 (0)