
Commit d861f4a

Fix test cases
Signed-off-by: AhnDW <[email protected]>
1 parent 31d3b90 commit d861f4a

File tree

2 files changed: +49 -35 lines

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 4 additions & 5 deletions
@@ -101,7 +101,6 @@ auto batch_norm_registrations TRTORCH_UNUSED =
        // track_running_stats=True
        LOG_DEBUG("Args[3] running_mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
        LOG_DEBUG("Args[4] running_var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
-
        LOG_DEBUG("use_input_stats, momemtum, cudnn_enabled disregarded");
        LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);

@@ -112,15 +111,15 @@ auto batch_norm_registrations TRTORCH_UNUSED =
        }

        auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
-
+
        auto scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
        auto bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();

        // track_running_stats=True
        if (!args[3].IValue()->isNone() || !args[4].IValue()->isNone()) {
-         auto running_mean = args[3].unwrapToTensor().cpu().contiguous();
-         auto running_var = args[4].unwrapToTensor().cpu().contiguous();
-         _batch_norm(ctx, n, input, orig_shape, scales, bias, running_mean, running_var, eps);
+         auto running_mean = args[3].unwrapToTensor();
+         auto running_var = args[4].unwrapToTensor();
+         _batch_norm(ctx, n, input, orig_shape, scales.to(running_mean.options()), bias.to(running_mean.options()), running_mean, running_var, eps);
          return true;
        }

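Note on the converter change above: running_mean and running_var are no longer forced through .cpu().contiguous(); instead, scales and bias are converted to the running stats' dtype and device at the _batch_norm call. A minimal standalone sketch of what Tensor::to(TensorOptions) does here, assuming only stock libtorch (the names below are illustrative, not from the commit):

#include <iostream>
#include <torch/torch.h>

// Sketch: to(other.options()) aligns dtype and device with another tensor,
// which is what scales.to(running_mean.options()) relies on in the diff.
int main() {
  auto running_mean = torch::zeros({5}, torch::dtype(torch::kHalf)); // pretend fp16 stats
  auto scales = torch::ones({5});                                    // default fp32
  auto aligned = scales.to(running_mean.options());                  // now fp16, same device
  std::cout << aligned.dtype() << " / " << aligned.device() << std::endl;
}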
tests/core/conversion/converters/test_instance_norm.cpp

Lines changed: 45 additions & 30 deletions
@@ -11,33 +11,41 @@
 // const c10::optional<Tensor>& running_mean_opt /* optional */,
 // const c10::optional<Tensor>& running_var_opt /* optional */,
 // bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
-inline constexpr auto graph = R"IR(
-  graph(%0 : Tensor,
-        %1 : Tensor?,
-        %2 : Tensor?,
-        %3 : Tensor?,
-        %4 : Tensor?,
-        %5 : bool):
-    %9 : bool = prim::Constant[value=0]()
-    %6 : float = prim::Constant[value=0.10000000000000001]()
-    %7 : float = prim::Constant[value=1.0000000000000001e-05]()
-    %8 : Tensor = aten::instance_norm(%0, %1, %2, %3, %4, %5, %6, %7, %9)
-    return (%8)
+constexpr auto graph = R"IR(
+  graph(%input.1 : Tensor,
+        %weight.1 : Tensor?,
+        %bias.1 : Tensor?,
+        %running_mean.1 : Tensor?,
+        %running_var.1 : Tensor?,
+        %use_input_stats.1 : bool):
+    %cudnn_enabled.1 : bool = prim::Constant[value=1]()
+    %momentum.1 : float = prim::Constant[value=0.10000000000000001]()
+    %eps.1 : float = prim::Constant[value=1.0000000000000001e-05]()
+    %4 : Tensor = aten::instance_norm(%input.1,
+        %weight.1, %bias.1,
+        %running_mean.1, %running_var.1,
+        %use_input_stats.1, %momentum.1, %eps.1, %cudnn_enabled.1)
+    return (%4)
 )IR";

+
 TEST(Converters, ATenInstanceNormConvertsCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());

   auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
   torch::jit::IValue weight, bias, mean, var; // NoneType
-  bool use_input_stats = true;
+  // https://github.com/pytorch/pytorch/blob/79693bb86a3f601a5c0d3da52d99acec95bb48c1/torch/nn/modules/instancenorm.py#L59
+  const bool use_input_stats = true;
+
+  auto trt_in = at::clone(in);
+  torch::jit::IValue trt_weight, trt_bias, trt_mean, trt_var;

   auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
   auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

-  params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
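The rewritten test feeds the TensorRT run its own clones (trt_in, trt_weight, ...) rather than reusing the tensors the JIT run already consumed. A small sketch of the property this leans on, assuming stock libtorch (illustrative only, not from the commit):

#include <iostream>
#include <torch/torch.h>

// Sketch: at::clone makes a storage-independent copy, so in-place updates
// during one graph run cannot leak into the other run's inputs or params.
int main() {
  auto in = at::randint(1, 10, {1, 5, 5, 5});
  auto trt_in = at::clone(in);
  in.add_(1);                                  // mutate the original in place
  std::cout << trt_in.equal(in) << std::endl;  // prints 0: the clone is unaffected
}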
@@ -50,37 +58,44 @@ TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {

   auto weight = at::randn({in.size(1)}).to(at::kCUDA);
   auto bias = at::randn({in.size(1)}).to(at::kCUDA);
+
+  torch::jit::IValue mean, var; // NoneType
+  const bool use_input_stats = true;

-  torch::jit::IValue mean, var; // NoneType
-  bool use_input_stats = true;
+  auto trt_in = at::clone(in);
+  auto trt_weight = at::clone(weight);
+  auto trt_bias = at::clone(bias);
+  torch::jit::IValue trt_mean, trt_var;

   auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
   auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

-  params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }

-
 TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());

-  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
-
-  torch::jit::IValue weight, bias; // NoneType
-
-  auto mean = at::randn({in.size(1)}).to(at::kCUDA);
-  auto var = at::randn({in.size(1)}).to(at::kCUDA);
-  bool use_input_stats = false;
+  auto in = at::randn({1, 5, 5, 5}, {at::kCUDA});
+
+  torch::jit::IValue weight, bias;
+  auto mean = at::zeros({in.size(1)}, {at::kCUDA});
+  auto var = at::ones({in.size(1)}, {at::kCUDA});
+  const bool use_input_stats = false;
+
+  auto trt_in = at::clone(in);
+  torch::jit::IValue trt_weight, trt_bias;
+  auto trt_mean = at::clone(mean);
+  auto trt_var = at::clone(var);

   auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
   auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

-  params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
-
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
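The running-stats test also swaps at::randn for at::zeros (mean) and at::ones (var). With those stats, (x - mean) / sqrt(var + eps) collapses to x / sqrt(1 + eps), a near-identity map, so any backend disagreement stands out. A hedged sketch of that arithmetic using libtorch's at::batch_norm as a stand-in (the converter's _batch_norm helper emits TensorRT layers and is a different function):

#include <iostream>
#include <torch/torch.h>

// Sketch: with running_mean = 0 and running_var = 1, eval-mode
// normalization scales the input by 1/sqrt(1 + 1e-5) ~= 0.999995,
// so the output stays numerically very close to the input.
int main() {
  auto x = at::randn({1, 5, 5, 5});
  auto y = at::batch_norm(x, /*weight=*/{}, /*bias=*/{},
                          /*running_mean=*/at::zeros({5}),
                          /*running_var=*/at::ones({5}),
                          /*training=*/false, /*momentum=*/0.1,
                          /*eps=*/1e-5, /*cudnn_enabled=*/false);
  std::cout << (y - x).abs().max().item<float>() << std::endl; // tiny residual
}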
