Skip to content

Commit 53f1e02

Browse files
Shixiaowei02 authored and kolinwei committed
fix infer crashes caused by conv/pool upgrades, test=release/1.6 (#20969)
* fix infer crashes caused by conv/pool upgrades, test=release/1.6 * fix bug, test=release/1.6
1 parent 6f0b2b1 commit 53f1e02

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

paddle/fluid/framework/op_compatible_info.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,5 +215,50 @@ bool OpCompatibleMap::ReadFromProto(const proto::OpCompatibleMap& desc) {
215215
return true;
216216
}
217217

218+
// Scans every op in the program and reports whether any conv/pool operator
// carries upgraded attribute values (asymmetric paddings, NHWC/NDHWC layout,
// or a non-EXPLICIT padding algorithm) that the IR optimization passes do
// not yet support. Returns true when such an operator is present.
bool ProgOptimUnsupported(std::shared_ptr<framework::ProgramDesc> program) {
  // True when `type` names one of the conv/pool ops affected by the upgrade.
  auto is_upgraded_op = [](const std::string& type) -> bool {
    const std::vector<std::string> upgraded({
        "conv2d", "conv3d", "conv2d_transpose", "conv3d_transpose",
        "depthwise_conv2d", "depthwise_conv2d_transpose", "pool2d", "pool3d",
    });
    auto it = std::find(upgraded.begin(), upgraded.end(), type);
    return it != upgraded.end();
  };
  // True when the op holds attribute values the optimizer cannot handle.
  auto has_unsupported_attr = [](const framework::OpDesc& op) -> bool {
    // Asymmetric padding: the upgraded ops may carry twice as many padding
    // entries as stride entries; the old passes assume equal sizes.
    if (op.HasAttr("paddings") && op.HasAttr("strides")) {
      const auto paddings =
          boost::get<std::vector<int>>(op.GetAttr("paddings"));
      const auto strides = boost::get<std::vector<int>>(op.GetAttr("strides"));
      if (paddings.size() != strides.size()) {
        VLOG(3) << "== paddings size is not equal to strides size.";
        return true;
      }
    }
    // Channel-last layouts are not handled by the optimization passes.
    if (op.HasAttr("data_format")) {
      const auto data_format =
          boost::get<std::string>(op.GetAttr("data_format"));
      if (data_format == "NHWC" || data_format == "NDHWC") {
        VLOG(3) << "== data_format is NHWC or NDHWC.";
        return true;
      }
    }
    // SAME/VALID padding algorithms compute paddings at runtime, which the
    // passes cannot reason about; only EXPLICIT is supported.
    if (op.HasAttr("padding_algorithm")) {
      const auto padding_algorithm =
          boost::get<std::string>(op.GetAttr("padding_algorithm"));
      if (padding_algorithm != "EXPLICIT") {
        VLOG(3) << "== padding_algorithm is not EXPLICIT.";
        return true;
      }
    }
    return false;
  };
  // Walk every block of the program; bail out on the first offending op.
  for (size_t block_idx = 0; block_idx < program->Size(); ++block_idx) {
    const auto& block = program->Block(block_idx);
    for (auto* op : block.AllOps()) {
      if (is_upgraded_op(op->Type()) && has_unsupported_attr(*op)) {
        return true;
      }
    }
  }
  return false;
}
262+
218263
} // namespace framework
219264
} // namespace paddle

paddle/fluid/framework/op_compatible_info.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include <map>
16+
#include <memory>
1617
#include <string>
1718
#include "paddle/fluid/framework/program_desc.h"
1819

@@ -70,5 +71,9 @@ class OpCompatibleMap {
7071
std::string default_required_version_;
7172
};
7273

74+
// Determine if the model contains operators that the optimization cannot
75+
// support.
76+
bool ProgOptimUnsupported(std::shared_ptr<framework::ProgramDesc> program);
77+
7378
} // namespace framework
7479
} // namespace paddle

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ bool AnalysisPredictor::PrepareProgram(
145145
// still need to create other persistable variables.
146146
// So in both case, create persistable variables at first.
147147
if (!CheckOperatorCompatible()) {
148-
LOG(WARNING) << "WARNING: Results may be DIFF! "
148+
LOG(WARNING) << "WARNING: Results may be incorrect! "
149149
"Using same versions between model and lib.";
150150
}
151151
executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
@@ -458,6 +458,14 @@ void AnalysisPredictor::PrepareArgument() {
458458

459459
// NOTE All the members in AnalysisConfig should be copied to Argument.
460460
void AnalysisPredictor::OptimizeInferenceProgram() {
461+
if (ProgOptimUnsupported(inference_program_)) {
462+
LOG(INFO) << "NOTICE: Your inference model contains parameters such "
463+
"as asymmetric padding, and ir optimization is temporarily "
464+
"not supported, "
465+
"so it is turned off.";
466+
config_.SwitchIrOptim(false);
467+
argument_.SetEnableAnalysisOptim(false);
468+
}
461469
PrepareArgument();
462470
Analyzer().Run(&argument_);
463471

0 commit comments

Comments
 (0)