Skip to content

Commit 71cb612

Browse files
authored
[Fix] fix the issue with batch norm mapper (#1550)
1 parent 1375b35 commit 71cb612

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

paddle2onnx/mapper/nn/batch_norm.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void BatchNormMapper::Opset14() {
103103

104104
std::vector<std::string> output_names;
105105
output_names.push_back(output_info[0].name);
106-
if (!use_global_stats_) {
106+
if (trainable_statistics_) {
107107
output_names.push_back(mean_out_info[0].name);
108108
output_names.push_back(variance_out_info[0].name);
109109
}
@@ -121,7 +121,8 @@ void BatchNormMapper::Opset14() {
121121

122122
AddAttribute(node, "epsilon", epsilon_);
123123
AddAttribute(node, "momentum", momentum_);
124-
AddAttribute(node, "training_mode", static_cast<int64_t>(!use_global_stats_));
124+
AddAttribute(
125+
node, "training_mode", static_cast<int64_t>(trainable_statistics_));
125126
}
126127

127128
} // namespace paddle2onnx

0 commit comments

Comments
 (0)