Skip to content

Commit 86982e1

Browse files
committed
chore: Add test case
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 8e60a54 commit 86982e1

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/cpp/test_dynamic_fallback.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "tests/util/util.h"
4+
#include "torch/script.h"
5+
#include "torch_tensorrt/torch_tensorrt.h"
6+
7+
TEST(CppAPITest, ResNet50DynamicFallbackGraphCorrectly) {
8+
torch::jit::script::Module mod;
9+
try {
10+
mod = torch::jit::load("tests/modules/resnet18_scripted.jit.pt");
11+
} catch (const c10::Error& e) {
12+
std::cerr << "error loading the model\n";
13+
ASSERT_TRUE(false);
14+
}
15+
16+
const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}, {4, 3, 224, 224}, {8, 3, 224, 224}};
17+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
18+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
19+
auto in = at::randint(5, input_shapes[0], {at::kCUDA});
20+
jit_inputs_ivalues.push_back(in.clone());
21+
trt_inputs_ivalues.push_back(in.clone());
22+
23+
std::vector<torch_tensorrt::Input> inputs;
24+
inputs.push_back(torch_tensorrt::Input(input_shapes[0], input_shapes[1], input_shapes[2]));
25+
torch_tensorrt::ts::CompileSpec cfg(inputs);
26+
cfg.torch_executed_ops.push_back("aten::add");
27+
28+
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
29+
auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
30+
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
31+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results));
32+
}

0 commit comments

Comments
 (0)