File tree Expand file tree Collapse file tree 3 files changed +13
-3
lines changed Expand file tree Collapse file tree 3 files changed +13
-3
lines changed Original file line number Diff line number Diff line change @@ -62,6 +62,12 @@ struct SimpleOpTypeSetTeller : public Teller {
62
62
63
63
bool OpTeller::Tell (const std::string& op_type, const framework::OpDesc& desc) {
64
64
for (auto & teller : tellers_) {
65
+ if (op_type == " pool2d" || op_type == " conv2d" ||
66
+ op_type == " depthwise_conv2d" || op_type == " conv2d_transpose" ) {
67
+ std::vector<int > paddings =
68
+ boost::get<std::vector<int >>(desc.GetAttr (" paddings" ));
69
+ if (paddings.size () > 2 ) return false ;
70
+ }
65
71
if ((*teller)(op_type, desc)) return true ;
66
72
}
67
73
return false ;
Original file line number Diff line number Diff line change @@ -123,7 +123,7 @@ bool AnalysisPredictor::PrepareScope(
123
123
status_is_cloned_ = true ;
124
124
} else {
125
125
if (config_.use_gpu_ ) {
126
- paddle::framework::InitDevices (false , {config_. device_id_ } );
126
+ paddle::framework::InitDevices (false );
127
127
} else {
128
128
paddle::framework::InitDevices (false , {});
129
129
}
@@ -500,8 +500,6 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
500
500
std::string flag = " --fraction_of_gpu_memory_to_use=" +
501
501
std::to_string (fraction_of_gpu_memory);
502
502
flags.push_back (flag);
503
- flags.push_back (" --selected_gpus=" +
504
- std::to_string (config.gpu_device_id ()));
505
503
VLOG (3 ) << " set flag: " << flag;
506
504
framework::InitGflags (flags);
507
505
}
Original file line number Diff line number Diff line change @@ -57,6 +57,12 @@ struct SimpleOpTypeSetTeller : public Teller {
57
57
58
58
bool OpTeller::Tell (const std::string& op_type, const framework::OpDesc& desc) {
59
59
for (auto & teller : tellers_) {
60
+ if (op_type == " pool2d" || op_type == " conv2d" ||
61
+ op_type == " depthwise_conv2d" || op_type == " conv2d_transpose" ) {
62
+ std::vector<int > paddings =
63
+ boost::get<std::vector<int >>(desc.GetAttr (" paddings" ));
64
+ if (paddings.size () > 2 ) return false ;
65
+ }
60
66
if ((*teller)(op_type, desc)) return true ;
61
67
}
62
68
return false ;
You can’t perform that action at this time.
0 commit comments