Skip to content

Commit af2a2e7

Browse files
DannyIsFunny authored and Superjomn committed
add the method of loading model from naive buffer for LightPredictor (#1918) (#1937)
1 parent 72c919d commit af2a2e7

File tree

12 files changed

+191
-25
lines changed

12 files changed

+191
-25
lines changed

lite/api/light_api.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,26 @@ namespace paddle {
1818
namespace lite {
1919

2020
void LightPredictor::Build(const std::string& model_dir,
21-
lite_api::LiteModelType model_type) {
21+
const std::string& model_buffer,
22+
const std::string& param_buffer,
23+
lite_api::LiteModelType model_type,
24+
bool model_from_memory) {
2225
cpp::ProgramDesc desc;
23-
LOG(INFO) << "Load model from " << model_dir;
2426
switch (model_type) {
2527
#ifndef LITE_ON_TINY_PUBLISH
2628
case lite_api::LiteModelType::kProtobuf:
2729
LoadModelPb(model_dir, "", "", scope_.get(), &desc);
2830
break;
2931
#endif
30-
case lite_api::LiteModelType::kNaiveBuffer:
31-
LoadModelNaive(model_dir, scope_.get(), &desc);
32+
case lite_api::LiteModelType::kNaiveBuffer: {
33+
if (model_from_memory) {
34+
LoadModelNaiveFromMemory(
35+
model_buffer, param_buffer, scope_.get(), &desc);
36+
} else {
37+
LoadModelNaive(model_dir, scope_.get(), &desc);
38+
}
3239
break;
40+
}
3341
default:
3442
LOG(FATAL) << "Unknown model type";
3543
}
@@ -83,11 +91,5 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
8391
program_->set_exec_scope(program.exec_scope());
8492
}
8593

86-
LightPredictor::LightPredictor(const std::string& model_dir,
87-
lite_api::LiteModelType model_type) {
88-
scope_ = std::make_shared<Scope>();
89-
Build(model_dir, model_type);
90-
}
91-
9294
} // namespace lite
9395
} // namespace paddle

lite/api/light_api.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,15 @@ namespace lite {
3838
*/
3939
class LITE_API LightPredictor {
4040
public:
41-
explicit LightPredictor(
41+
LightPredictor(
4242
const std::string& model_dir,
43-
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
43+
const std::string& model_buffer = "",
44+
const std::string& param_buffer = "",
45+
bool model_from_memory = false,
46+
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf) {
47+
scope_ = std::make_shared<Scope>();
48+
Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory);
49+
}
4450

4551
void Run() { program_->Run(); }
4652

@@ -58,7 +64,11 @@ class LITE_API LightPredictor {
5864
private:
5965
void Build(
6066
const std::string& model_dir,
61-
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
67+
const std::string& model_buffer,
68+
const std::string& param_buffer,
69+
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
70+
bool model_from_memory = false);
71+
6272
void BuildRuntimeProgram(const cpp::ProgramDesc& prog);
6373

6474
private:

lite/api/light_api_impl.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ void LightPredictorImpl::Init(const MobileConfig& config) {
4545
lite::DeviceInfo::Global().SetRunMode(config.power_mode(), config.threads());
4646
#endif
4747
raw_predictor_.reset(new lite::LightPredictor(config.model_dir(),
48+
config.model_buffer(),
49+
config.param_buffer(),
50+
config.model_from_memory(),
4851
LiteModelType::kNaiveBuffer));
4952
}
5053

lite/api/light_api_test.cc

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,46 @@ TEST(LightAPI, load) {
2828
if (FLAGS_optimized_model.empty()) {
2929
FLAGS_optimized_model = "lite_naive_model";
3030
}
31-
LightPredictor predictor(FLAGS_optimized_model);
31+
LightPredictor predictor(FLAGS_optimized_model, "", "");
32+
auto* input_tensor = predictor.GetInput(0);
33+
input_tensor->Resize(DDim(std::vector<int64_t>({100, 100})));
34+
auto* data = input_tensor->mutable_data<float>();
35+
for (int i = 0; i < 100 * 100; i++) {
36+
data[i] = i;
37+
}
38+
39+
predictor.Run();
40+
41+
const auto* output = predictor.GetOutput(0);
42+
const float* raw_output = output->data<float>();
43+
44+
for (int i = 0; i < 10; i++) {
45+
LOG(INFO) << "out " << raw_output[i];
46+
}
47+
}
48+
49+
TEST(LightAPI, loadNaiveBuffer) {
50+
if (FLAGS_optimized_model.empty()) {
51+
FLAGS_optimized_model = "lite_naive_model";
52+
}
53+
54+
auto model_path = std::string(FLAGS_optimized_model) + "/__model__.nb";
55+
auto params_path = std::string(FLAGS_optimized_model) + "/param.nb";
56+
std::string model_buffer = lite::ReadFile(model_path);
57+
size_t size_model = model_buffer.length();
58+
std::string params_buffer = lite::ReadFile(params_path);
59+
size_t size_params = params_buffer.length();
60+
LOG(INFO) << "sizeModel: " << size_model;
61+
LOG(INFO) << "sizeParams: " << size_params;
62+
63+
lite_api::MobileConfig config;
64+
config.set_model_buffer(
65+
model_buffer.c_str(), size_model, params_buffer.c_str(), size_params);
66+
LightPredictor predictor(config.model_dir(),
67+
config.model_buffer(),
68+
config.param_buffer(),
69+
config.model_from_memory(),
70+
lite_api::LiteModelType::kNaiveBuffer);
3271

3372
auto* input_tensor = predictor.GetInput(0);
3473
input_tensor->Resize(DDim(std::vector<int64_t>({100, 100})));

lite/api/paddle_api.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ class LITE_API CxxConfig : public ConfigBase {
129129
class LITE_API MobileConfig : public ConfigBase {
130130
PowerMode mode_{LITE_POWER_HIGH};
131131
int threads_{1};
132+
std::string model_buffer_;
133+
std::string param_buffer_;
134+
bool model_from_memory_{false};
132135

133136
public:
134137
MobileConfig(Place preferred_place = Place(TARGET(kARM),
@@ -139,9 +142,20 @@ class LITE_API MobileConfig : public ConfigBase {
139142
: mode_(mode), threads_(threads) {}
140143
void set_power_mode(PowerMode mode) { mode_ = mode; }
141144
void set_threads(int threads) { threads_ = threads; }
145+
void set_model_buffer(const char* model_buffer,
146+
size_t model_buffer_size,
147+
const char* param_buffer,
148+
size_t param_buffer_size) {
149+
model_buffer_ = std::string(model_buffer, model_buffer + model_buffer_size);
150+
param_buffer_ = std::string(param_buffer, param_buffer + param_buffer_size);
151+
model_from_memory_ = true;
152+
}
142153

143154
PowerMode power_mode() const { return mode_; }
144155
int threads() const { return threads_; }
156+
bool model_from_memory() const { return model_from_memory_; }
157+
const std::string& model_buffer() const { return model_buffer_; }
158+
const std::string& param_buffer() const { return param_buffer_; }
145159
};
146160

147161
template <typename ConfigT>

lite/api/paddle_api_test.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "lite/api/paddle_use_ops.h"
2020
#include "lite/api/paddle_use_passes.h"
2121
#include "lite/utils/cp_logging.h"
22-
22+
#include "lite/utils/io.h"
2323
DEFINE_string(model_dir, "", "");
2424

2525
namespace paddle {
@@ -58,6 +58,7 @@ TEST(CxxApi, run) {
5858
LiteModelType::kNaiveBuffer);
5959
}
6060

61+
// Demo1 for Mobile Devices :Load model from file and run
6162
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
6263
TEST(LightApi, run) {
6364
lite_api::MobileConfig config;
@@ -82,6 +83,39 @@ TEST(LightApi, run) {
8283
EXPECT_NEAR(out[0], 50.2132, 1e-3);
8384
EXPECT_NEAR(out[1], -28.8729, 1e-3);
8485
}
86+
87+
// Demo2 for Loading model from memory
88+
TEST(MobileConfig, LoadfromMemory) {
89+
// Get naive buffer
90+
auto model_path = std::string(FLAGS_model_dir) + ".opt2.naive/__model__.nb";
91+
auto params_path = std::string(FLAGS_model_dir) + ".opt2.naive/param.nb";
92+
std::string model_buffer = lite::ReadFile(model_path);
93+
size_t size_model = model_buffer.length();
94+
std::string params_buffer = lite::ReadFile(params_path);
95+
size_t size_params = params_buffer.length();
96+
// set model buffer and run model
97+
lite_api::MobileConfig config;
98+
config.set_model_buffer(
99+
model_buffer.c_str(), size_model, params_buffer.c_str(), size_params);
100+
101+
auto predictor = lite_api::CreatePaddlePredictor(config);
102+
auto input_tensor = predictor->GetInput(0);
103+
input_tensor->Resize(std::vector<int64_t>({100, 100}));
104+
auto* data = input_tensor->mutable_data<float>();
105+
for (int i = 0; i < 100 * 100; i++) {
106+
data[i] = i;
107+
}
108+
109+
predictor->Run();
110+
111+
const auto output = predictor->GetOutput(0);
112+
const float* raw_output = output->data<float>();
113+
114+
for (int i = 0; i < 10; i++) {
115+
LOG(INFO) << "out " << raw_output[i];
116+
}
117+
}
118+
85119
#endif
86120

87121
} // namespace lite_api

lite/model_parser/model_parser.cc

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -661,9 +661,14 @@ void LoadParamNaive(const std::string &path,
661661

662662
void LoadCombinedParamsNaive(const std::string &path,
663663
lite::Scope *scope,
664-
const cpp::ProgramDesc &cpp_prog) {
664+
const cpp::ProgramDesc &cpp_prog,
665+
bool params_from_memory) {
665666
naive_buffer::BinaryTable table;
666-
table.LoadFromFile(path);
667+
if (params_from_memory) {
668+
table.LoadFromMemory(path.c_str(), path.length());
669+
} else {
670+
table.LoadFromFile(path);
671+
}
667672
naive_buffer::proto::CombinedParamsDesc pt_desc(&table);
668673
pt_desc.Load();
669674
naive_buffer::CombinedParamsDesc desc(&pt_desc);
@@ -710,7 +715,7 @@ void LoadModelNaive(const std::string &model_dir,
710715
// NOTE: Only main block be used now.
711716
if (combined) {
712717
const std::string combined_params_path = model_dir + "/param.nb";
713-
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog);
718+
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, false);
714719
} else {
715720
auto &prog = *cpp_prog;
716721
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
@@ -750,5 +755,40 @@ void LoadModelNaive(const std::string &model_dir,
750755
VLOG(4) << "Load naive buffer model in '" << model_dir << "' successfully";
751756
}
752757

758+
void LoadModelNaiveFromMemory(const std::string &model_buffer,
759+
const std::string &param_buffer,
760+
Scope *scope,
761+
cpp::ProgramDesc *cpp_prog) {
762+
CHECK(cpp_prog);
763+
CHECK(scope);
764+
cpp_prog->ClearBlocks();
765+
766+
// Load model
767+
768+
std::string prog_path = model_buffer;
769+
770+
naive_buffer::BinaryTable table;
771+
table.LoadFromMemory(prog_path.c_str(), prog_path.length());
772+
773+
naive_buffer::proto::ProgramDesc nb_proto_prog(&table);
774+
nb_proto_prog.Load();
775+
naive_buffer::ProgramDesc nb_prog(&nb_proto_prog);
776+
777+
// Transform to cpp::ProgramDesc
778+
TransformProgramDescAnyToCpp(nb_prog, cpp_prog);
779+
780+
// Load Params
781+
// NOTE: Only main block be used now.
782+
// only combined Params are supported in Loading Model from memory
783+
std::string combined_params_path = param_buffer;
784+
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, true);
785+
786+
#ifdef LITE_WITH_NPU
787+
LOG(FATAL) << "load from memory is not supported by NPU";
788+
#endif
789+
790+
VLOG(4) << "Load model from naive buffer memory successfully";
791+
}
792+
753793
} // namespace lite
754794
} // namespace paddle

lite/model_parser/model_parser.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,15 @@ void LoadParamNaive(const std::string& path,
9494
lite::Scope* scope,
9595
const std::string& name);
9696

97-
void LoadCombinedParamsNaive(const std::string& path,
98-
lite::Scope* scope,
99-
const cpp::ProgramDesc& cpp_prog);
100-
10197
void LoadModelNaive(const std::string& model_dir,
10298
lite::Scope* scope,
10399
cpp::ProgramDesc* prog,
104100
bool combined = true);
105101

102+
void LoadModelNaiveFromMemory(const std::string& model_buffer,
103+
const std::string& param_buffer,
104+
lite::Scope* scope,
105+
cpp::ProgramDesc* cpp_prog);
106+
106107
} // namespace lite
107108
} // namespace paddle

lite/model_parser/model_parser_test.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,17 @@ TEST(ModelParser, SaveModelNaive) {
121121
SaveModelNaive(save_pb_model_path, scope, prog);
122122
}
123123

124-
TEST(ModelParser, LoadModelNaive) {
124+
TEST(ModelParser, LoadModelNaiveFromMemory) {
125125
CHECK(!FLAGS_model_dir.empty());
126126
cpp::ProgramDesc prog;
127127
Scope scope;
128-
const std::string model_path = FLAGS_model_dir + ".saved.naive";
129-
LoadModelNaive(model_path, &scope, &prog);
128+
129+
auto model_path = std::string(FLAGS_model_dir) + ".saved.naive/__model__.nb";
130+
auto params_path = std::string(FLAGS_model_dir) + ".saved.naive/param.nb";
131+
std::string model_buffer = lite::ReadFile(model_path);
132+
std::string params_buffer = lite::ReadFile(params_path);
133+
134+
LoadModelNaiveFromMemory(model_buffer, params_buffer, &scope, &prog);
130135
}
131136

132137
} // namespace lite

lite/model_parser/naive_buffer/naive_buffer.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ void BinaryTable::LoadFromFile(const std::string &filename) {
6666
is_mutable_mode_ = false;
6767
}
6868

69+
void BinaryTable::LoadFromMemory(const char *buffer, size_t buffer_size) {
70+
// get buffer
71+
bytes_.resize(buffer_size);
72+
memcpy(reinterpret_cast<char *>(&bytes_[0]), buffer, buffer_size);
73+
// Set readonly.
74+
is_mutable_mode_ = false;
75+
}
76+
6977
void StringBuilder::Save() {
7078
// memory format: [size][string data]
7179
uint64_t mem_size = sizeof(uint64_t) + data_.size();

0 commit comments

Comments (0)