Skip to content

Commit 6754c79

Browse files
committed
chore: Add testcase
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 47ae984 commit 6754c79

File tree

5 files changed

+66
-7
lines changed

5 files changed

+66
-7
lines changed

core/conversion/converters/impl/shuffle.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
7878
concat_layer->setAxis(static_cast<int32_t>(0));
7979
shape_tensor = concat_layer->getOutput(0);
8080
} else {
81-
auto new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
81+
new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
8282
}
8383
auto shuffle = ctx->net->addShuffle(*in);
8484
shuffle->setName(util::node_info(n).c_str());

core/conversion/evaluators/prim.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#include <limits>
22

3-
#include "torch/csrc/jit/ir/ir.h"
4-
//#include "torch/csrc/jit/ir/constants.h"
53
#include "ATen/core/List.h"
64
#include "ATen/core/functional.h"
75
#include "ATen/core/ivalue.h"
86
#include "ATen/core/stack.h"
97
#include "c10/util/intrusive_ptr.h"
8+
#include "torch/csrc/jit/ir/ir.h"
109
#include "torch/torch.h"
1110

1211
#include "core/conversion/evaluators/eval_macros.h"
@@ -111,8 +110,20 @@ auto prim_registrations =
111110
tensor_holder.hold_tensor(itensor);
112111
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
113112
list.emplace_back(std::move(ival));
113+
} else if (args.at(in).IValue()->isDouble()) {
114+
auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
115+
ctx, torch::tensor({args.at(in).unwrapToDouble()}).to(torch::kFloat));
116+
auto tensor_holder = TensorContainer();
117+
tensor_holder.hold_tensor(itensor);
118+
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
119+
list.emplace_back(std::move(ival));
114120
} else {
115-
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
121+
auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
122+
ctx, std::move(args.at(in).unwrapToTensor()));
123+
auto tensor_holder = TensorContainer();
124+
tensor_holder.hold_tensor(itensor);
125+
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
126+
list.emplace_back(std::move(ival));
116127
}
117128
}
118129
}

py/requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
numpy
22
pybind11==2.6.2
3-
--extra-index-url https://download.pytorch.org/whl/nightly/cu117
4-
torch==2.0.0.dev20230103+cu117
5-
torchvision==0.15.0.dev20230103+cu117
3+
torch==1.13.0
4+
torchvision==0.14.0
65
--extra-index-url https://pypi.ngc.nvidia.com
76
tensorrt==8.5.1.7

tests/cpp/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ test_suite(
1616
":test_compiled_modules",
1717
":test_default_input_types",
1818
":test_dynamic_fallback",
19+
":test_dynamic_size",
1920
":test_example_tensors",
2021
":test_module_fallback",
2122
":test_modules_as_engines",
@@ -32,6 +33,7 @@ test_suite(
3233
":test_compiled_modules",
3334
":test_default_input_types",
3435
":test_dynamic_fallback",
36+
":test_dynamic_size",
3537
":test_example_tensors",
3638
":test_module_fallback",
3739
":test_modules_as_engines",
@@ -142,6 +144,18 @@ cc_test(
142144
}),
143145
)
144146

147+
cc_test(
148+
name = "test_dynamic_size",
149+
srcs = ["test_dynamic_size.cpp"],
150+
deps = [
151+
"//tests/util",
152+
"@googletest//:gtest_main",
153+
] + select({
154+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
155+
"//conditions:default": ["@libtorch//:libtorch"],
156+
}),
157+
)
158+
145159
cc_test(
146160
name = "test_collections",
147161
srcs = ["test_collections.cpp"],

tests/cpp/test_dynamic_size.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include <torch/torch.h>
2+
#include <string>
3+
#include "core/compiler.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
8+
TEST(Converters, ATenResizeDynamicInputCorrectly) {
9+
const auto graph = R"IR(
10+
graph(%x : Tensor):
11+
%3 : int = prim::Constant[value=0]()
12+
%2 : int = prim::Constant[value=-1]()
13+
%28 : int = aten::size(%x, %3)
14+
%30 : int[] = prim::ListConstruct(%28, %2)
15+
%6 : Tensor = aten::reshape(%x, %30)
16+
return (%6))IR";
17+
18+
auto g = std::make_shared<torch::jit::Graph>();
19+
20+
torch::jit::parseIR(graph, g.get());
21+
22+
auto in = at::randint(1, 10, {16, 3, 2}, {at::kCUDA});
23+
24+
auto jit_in = at::clone(in);
25+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
26+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
27+
28+
auto trt_in = at::clone(in);
29+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
30+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
31+
32+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
33+
34+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
35+
}

0 commit comments

Comments
 (0)