@@ -40,29 +40,27 @@ TEST(Converters, ATenBatchNormAffineFalseConvertsCorrectly) {
   // BatchNorm(ch, affine=False)
   const auto graph = R"IR(
       graph(%0 : Tensor,
-            %1: NoneType = prim::Constant(),
-            %2: NoneType = prim::Constant(),
             %3: Float(5, strides=[1]),
             %4: Float(5, strides=[1])):
+        %1 : None = prim::Constant()
         %5 : bool = prim::Constant[value=0]()
         %6 : float = prim::Constant[value=1.0000000000000001e-05]()
         %7 : float = prim::Constant[value=0.10000000000000001]()
-        %8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
+        %8 : Tensor = aten::batch_norm(%0, %1, %1, %3, %4, %5, %6, %7, %5)
         return (%8))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
   auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
 
-  torch::jit::IValue gamma, beta; // NoneType
   auto mean = at::randint(1, 10, {5}, {at::kCUDA});
   auto var = at::randint(1, 10, {5}, {at::kCUDA});
 
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
   auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
 
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
   ASSERT_TRUE(