Skip to content

Commit baec7a3

Browse files
authored
load inference model from memory buffer, test=release/1.7 (#22562)
* load model from memory; scale is no longer added when saving inference model (test=develop)
* raise CI coverage (test=develop)
* support saving weights to memory (test=develop)
* raise CI coverage (test=develop)
* fix PADDLE_ENFORCE messages (test=develop)
1 parent 4e3c535 commit baec7a3

File tree

5 files changed

+176
-85
lines changed

5 files changed

+176
-85
lines changed

paddle/fluid/operators/save_combine_op.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,23 @@ to a file on disk.
7171
"The \"file_path\" where the LoDTensor variables will be saved.")
7272
.AddCustomChecker(
7373
[](const std::string& path) { return !path.empty(); });
74+
AddAttr<bool>("save_to_memory",
75+
"(boolean, default false)"
76+
"If true, the variables will be saved to binary strings.")
77+
.SetDefault(false);
78+
AddOutput("Y",
79+
"(RAW, default empty)."
80+
"This output is used when saving variables to binary strings.")
81+
.AsDispensable();
82+
}
83+
};
84+
85+
// Variable-type inference for save_combine's optional output Y.
// Y is declared as "(RAW, default empty)" and, per the op's attribute docs,
// receives the serialized binary string when save_to_memory is true — so
// every name bound to the Y slot is typed as RAW rather than as a tensor.
class SaveCombineOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    const auto& y_var_names = ctx->Output("Y");
    for (const auto& y_var_name : y_var_names) {
      ctx->SetType(y_var_name, framework::proto::VarType::RAW);
    }
  }
};
7693

@@ -80,7 +97,7 @@ to a file on disk.
8097
namespace ops = paddle::operators;
8198

8299
REGISTER_OPERATOR(save_combine, ops::SaveCombineOp,
83-
ops::SaveCombineOpProtoMaker);
100+
ops::SaveCombineOpProtoMaker, ops::SaveCombineOpInferVarType);
84101

85102
REGISTER_OP_CPU_KERNEL(
86103
save_combine,

paddle/fluid/operators/save_combine_op.h

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,16 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
3838
auto filename = ctx.Attr<std::string>("file_path");
3939
auto overwrite = ctx.Attr<bool>("overwrite");
4040
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
41+
auto save_to_memory = ctx.Attr<bool>("save_to_memory");
42+
auto output = ctx.Output<std::string>("Y");
4143

4244
bool is_present = FileExists(filename);
4345
if (is_present && !overwrite) {
4446
PADDLE_THROW("%s exists!, cannot save_combine to it when overwrite=false",
4547
filename, overwrite);
4648
}
4749

48-
MkDirRecursively(DirName(filename).c_str());
49-
std::ofstream fout(filename, std::ios::binary);
50-
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
51-
filename);
52-
50+
std::ostringstream ss;
5351
auto inp_var_names = ctx.InputNames("X");
5452
auto &inp_vars = ctx.MultiInputVar("X");
5553
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
@@ -82,12 +80,25 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
8280
// copy LoD info to the new tensor
8381
out.set_lod(tensor.lod());
8482
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
85-
framework::SerializeToStream(fout, out, dev_ctx);
83+
framework::SerializeToStream(ss, out, dev_ctx);
8684
} else {
87-
framework::SerializeToStream(fout, tensor, dev_ctx);
85+
framework::SerializeToStream(ss, tensor, dev_ctx);
8886
}
8987
}
90-
fout.close();
88+
if (save_to_memory) {
89+
PADDLE_ENFORCE_NE(output, nullptr,
90+
platform::errors::InvalidArgument(
91+
"Cannot find variable Y for save_combine_op"));
92+
*output = ss.str();
93+
} else {
94+
MkDirRecursively(DirName(filename).c_str());
95+
std::ofstream fout(filename, std::ios::binary);
96+
PADDLE_ENFORCE_EQ(
97+
static_cast<bool>(fout), true,
98+
platform::errors::NotFound("Cannot open %s to write", filename));
99+
fout << ss.str();
100+
fout.close();
101+
}
91102
}
92103
};
93104

paddle/fluid/pybind/pybind.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,10 @@ All parameter, weight, gradient are variables in Paddle.
915915
return self.GetMutable<LoDTensor>();
916916
},
917917
py::return_value_policy::reference)
918+
.def("get_bytes",
919+
[](Variable &self) {
920+
return py::bytes(*self.GetMutable<std::string>());
921+
})
918922
.def("get_lod_rank_table",
919923
[](Variable &self) { return self.GetMutable<LoDRankTable>(); },
920924
py::return_value_policy::reference)

0 commit comments

Comments (0)