Skip to content

Commit 6476ce5

Browse files
authored
[Python API] add opt.enable_fp16 (#6048)
1 parent 6010b5f commit 6476ce5

File tree

4 files changed

+10
-9
lines changed

4 files changed

+10
-9
lines changed

lite/api/opt_base.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,12 @@ void OptBase::SetPassesInternal(
5959
opt_config_.set_passes_internal(passes_internal);
6060
}
6161

62-
void OptBase::SetValidPlaces(const std::string& valid_places,
63-
bool enable_fp16) {
62+
void OptBase::SetValidPlaces(const std::string& valid_places) {
6463
valid_places_.clear();
6564
auto target_reprs = lite::Split(valid_places, ",");
6665
for (auto& target_repr : target_reprs) {
6766
if (target_repr == "arm") {
68-
if (enable_fp16) {
67+
if (enable_fp16_) {
6968
valid_places_.emplace_back(
7069
Place{TARGET(kARM), PRECISION(kFP16), DATALAYOUT(kNCHW)});
7170
}
@@ -149,13 +148,12 @@ void OptBase::RunOptimize(const std::string& model_dir_path,
149148
const std::string& param_path,
150149
const std::string& model_type,
151150
const std::string& valid_places,
152-
const bool enable_fp16,
153151
const std::string& optimized_out_path) {
154152
SetModelDir(model_dir_path);
155153
SetModelFile(model_path);
156154
SetParamFile(param_path);
157155
SetModelType(model_type);
158-
SetValidPlaces(valid_places, enable_fp16);
156+
SetValidPlaces(valid_places);
159157
SetOptimizeOut(optimized_out_path);
160158
CheckIfModelSupported(false);
161159
OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map);

lite/api/opt_base.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ class LITE_API OptBase {
4747
void SetModelDir(const std::string &model_dir_path);
4848
void SetModelFile(const std::string &model_path);
4949
void SetParamFile(const std::string &param_path);
50-
void SetValidPlaces(const std::string &valid_places, bool enable_fp16);
50+
void EnableFloat16() { enable_fp16_ = true; }
51+
void SetValidPlaces(const std::string &valid_places);
5152
void SetOptimizeOut(const std::string &lite_out_name);
5253
void RecordModelInfo(bool record_strip_info = true);
5354
void SetQuantModel(bool quant_model);
@@ -64,7 +65,6 @@ class LITE_API OptBase {
6465
const std::string &param_path = "",
6566
const std::string &model_type = "",
6667
const std::string &valid_places = "",
67-
const bool enable_fp16 = false,
6868
const std::string &optimized_out_path = "");
6969
// fuctions of printing info
7070
// 1. help info
@@ -83,6 +83,7 @@ class LITE_API OptBase {
8383
void CheckIfModelSupported(bool print_ops_info = true);
8484

8585
private:
86+
bool enable_fp16_{false};
8687
CxxConfig opt_config_;
8788
// valid places for the optimized_model
8889
std::vector<Place> valid_places_;

lite/api/python/bin/paddle_lite_opt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@ def main():
7676
a.set_optimize_out(args.optimize_out)
7777
if args.valid_targets is not None:
7878
if args.enable_fp16 == "true":
79-
a.set_valid_places(args.valid_targets, True)
79+
a.enable_fp16()
80+
a.set_valid_places(args.valid_targets)
8081
else:
81-
a.set_valid_places(args.valid_targets, False)
82+
a.set_valid_places(args.valid_targets)
8283
if args.param_file is not None:
8384
a.set_param_file(args.param_file)
8485
if args.record_tailoring_info == "true":

lite/api/python/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ void BindLiteOpt(py::module *m) {
6464
.def("set_model_file", &OptBase::SetModelFile)
6565
.def("set_param_file", &OptBase::SetParamFile)
6666
.def("set_valid_places", &OptBase::SetValidPlaces)
67+
.def("enable_fp16", &OptBase::EnableFloat16)
6768
.def("set_optimize_out", &OptBase::SetOptimizeOut)
6869
.def("set_model_type", &OptBase::SetModelType)
6970
.def("set_quant_model", &OptBase::SetQuantModel)

0 commit comments

Comments
 (0)