Skip to content

Commit 12942ac

Browse files
committed
refactor(//tests): Fixing batchnorm false test
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 6f134fa commit 12942ac

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

tests/core/conversion/converters/test_batch_norm.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,29 +40,27 @@ TEST(Converters, ATenBatchNormAffineFalseConvertsCorrectly) {
4040
// BatchNorm(ch, affine=False)
4141
const auto graph = R"IR(
4242
graph(%0 : Tensor,
43-
%1: NoneType = prim::Constant(),
44-
%2: NoneType = prim::Constant(),
4543
%3: Float(5, strides=[1]),
4644
%4: Float(5, strides=[1])):
45+
%1 : None = prim::Constant()
4746
%5 : bool = prim::Constant[value=0]()
4847
%6 : float = prim::Constant[value=1.0000000000000001e-05]()
4948
%7 : float = prim::Constant[value=0.10000000000000001]()
50-
%8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
49+
%8 : Tensor = aten::batch_norm(%0, %1, %1, %3, %4, %5, %6, %7, %5)
5150
return (%8))IR";
5251

5352
auto g = std::make_shared<torch::jit::Graph>();
5453
torch::jit::parseIR(graph, g.get());
5554

5655
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
5756

58-
torch::jit::IValue gamma, beta; // NoneType
5957
auto mean = at::randint(1, 10, {5}, {at::kCUDA});
6058
auto var = at::randint(1, 10, {5}, {at::kCUDA});
6159

62-
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
60+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
6361
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
6462

65-
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
63+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
6664
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
6765

6866
ASSERT_TRUE(

0 commit comments

Comments
 (0)