// const c10::optional<Tensor>& running_mean_opt /* optional */,
// const c10::optional<Tensor>& running_var_opt /* optional */,
// bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
-inline constexpr auto graph = R"IR(
-  graph(%0 : Tensor,
-        %1 : Tensor?,
-        %2 : Tensor?,
-        %3 : Tensor?,
-        %4 : Tensor?,
-        %5 : bool):
-    %9 : bool = prim::Constant[value=0]()
-    %6 : float = prim::Constant[value=0.10000000000000001]()
-    %7 : float = prim::Constant[value=1.0000000000000001e-05]()
-    %8 : Tensor = aten::instance_norm(%0, %1, %2, %3, %4, %5, %6, %7, %9)
-    return (%8)
+constexpr auto graph = R"IR(
+  graph(%input.1 : Tensor,
+        %weight.1 : Tensor?,
+        %bias.1 : Tensor?,
+        %running_mean.1 : Tensor?,
+        %running_var.1 : Tensor?,
+        %use_input_stats.1 : bool):
+    %cudnn_enabled.1 : bool = prim::Constant[value=1]()
+    %momentum.1 : float = prim::Constant[value=0.10000000000000001]()
+    %eps.1 : float = prim::Constant[value=1.0000000000000001e-05]()
+    %4 : Tensor = aten::instance_norm(%input.1,
+                                      %weight.1, %bias.1,
+                                      %running_mean.1, %running_var.1,
+                                      %use_input_stats.1, %momentum.1, %eps.1, %cudnn_enabled.1)
+    return (%4)
)IR";

+
TEST(Converters, ATenInstanceNormConvertsCorrectly) {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
  torch::jit::IValue weight, bias, mean, var; // NoneType
-  bool use_input_stats = true;
+  // https://github.com/pytorch/pytorch/blob/79693bb86a3f601a5c0d3da52d99acec95bb48c1/torch/nn/modules/instancenorm.py#L59
+  const bool use_input_stats = true;
+
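+  // The TensorRT run below gets its own copies of every input, so the
+  // two executions cannot interact through shared storage.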
+  auto trt_in = at::clone(in);
+  torch::jit::IValue trt_weight, trt_bias, trt_mean, trt_var;

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

-  params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
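For reference, the graph string above matches the aten::instance_norm schema quoted in the comment at the top, with momentum = 0.1, eps = 1e-5, and cudnn_enabled = true baked in as constants. A minimal sketch of the equivalent direct ATen call, assuming a CUDA-enabled libtorch (illustration only, not part of this change):

    #include <ATen/ATen.h>

    int main() {
      auto in = at::randn({1, 5, 5, 5}, at::kCUDA);
      // weight, bias, running_mean and running_var stay empty (None),
      // mirroring the NoneType IValues in the first test above.
      auto out = at::instance_norm(
          in, {}, {}, {}, {},
          /*use_input_stats=*/true, /*momentum=*/0.1, /*eps=*/1e-5,
          /*cudnn_enabled=*/true);
      return 0;
    }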
@@ -50,37 +58,44 @@ TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {

  auto weight = at::randn({in.size(1)}).to(at::kCUDA);
  auto bias = at::randn({in.size(1)}).to(at::kCUDA);
+
+  torch::jit::IValue mean, var; // NoneType
+  const bool use_input_stats = true;

-  torch::jit::IValue mean, var; // NoneType
-  bool use_input_stats = true;
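+  // As in the first test, the TensorRT run gets independent copies.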
+  auto trt_in = at::clone(in);
+  auto trt_weight = at::clone(weight);
+  auto trt_bias = at::clone(bias);
+  torch::jit::IValue trt_mean, trt_var;

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

-  params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

-
TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

-  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
-
-  torch::jit::IValue weight, bias; // NoneType
-
-  auto mean = at::randn({in.size(1)}).to(at::kCUDA);
-  auto var = at::randn({in.size(1)}).to(at::kCUDA);
-  bool use_input_stats = false;
+  auto in = at::randn({1, 5, 5, 5}, {at::kCUDA});
+
+  torch::jit::IValue weight, bias;
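+  // Use well-formed running stats: a variance must be non-negative, which
+  // at::randn did not guarantee. With use_input_stats = false the op
+  // normalizes with these stats, yielding roughly in / sqrt(1 + eps).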
+  auto mean = at::zeros({in.size(1)}, {at::kCUDA});
+  auto var = at::ones({in.size(1)}, {at::kCUDA});
+  const bool use_input_stats = false;
+
+  auto trt_in = at::clone(in);
+  torch::jit::IValue trt_weight, trt_bias;
+  auto trt_mean = at::clone(mean);
+  auto trt_var = at::clone(var);

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

-  params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
-
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}