Skip to content

Commit 171df5b

Browse files
authored
Merge pull request #16303 from junjun315/checkpoint
for Checkpoint save and load
2 parents e3bca90 + ac32bf6 commit 171df5b

19 files changed

+1014
-352
lines changed

paddle/fluid/operators/load_combine_op.cc

Lines changed: 30 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -11,89 +11,27 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14-
#include <fstream>
15-
#include "paddle/fluid/framework/data_type_transform.h"
16-
#include "paddle/fluid/framework/op_registry.h"
17-
#include "paddle/fluid/platform/device_context.h"
14+
15+
#include <string>
16+
#include <vector>
17+
18+
#include "paddle/fluid/operators/load_combine_op.h"
1819

1920
namespace paddle {
2021
namespace operators {
2122

22-
class LoadCombineOp : public framework::OperatorBase {
23+
class LoadCombineOp : public framework::OperatorWithKernel {
2324
public:
24-
LoadCombineOp(const std::string &type,
25-
const framework::VariableNameMap &inputs,
26-
const framework::VariableNameMap &outputs,
27-
const framework::AttributeMap &attrs)
28-
: OperatorBase(type, inputs, outputs, attrs) {}
29-
30-
private:
31-
void RunImpl(const framework::Scope &scope,
32-
const platform::Place &place) const override {
33-
auto filename = Attr<std::string>("file_path");
34-
auto load_as_fp16 = Attr<bool>("load_as_fp16");
35-
auto model_from_memory = Attr<bool>("model_from_memory");
36-
auto out_var_names = Outputs("Out");
37-
PADDLE_ENFORCE_GT(
38-
static_cast<int>(out_var_names.size()), 0,
39-
"The number of output variables should be greater than 0.");
40-
if (!model_from_memory) {
41-
std::ifstream fin(filename, std::ios::binary);
42-
PADDLE_ENFORCE(static_cast<bool>(fin),
43-
"Cannot open file %s for load_combine op", filename);
44-
LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
45-
} else {
46-
PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
47-
std::stringstream fin(filename, std::ios::in | std::ios::binary);
48-
LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
49-
}
50-
}
51-
void LoadParamsFromBuffer(
52-
const framework::Scope &scope, const platform::Place &place,
53-
std::istream *buffer, bool load_as_fp16,
54-
const std::vector<std::string> &out_var_names) const {
55-
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
56-
auto &dev_ctx = *pool.Get(place);
57-
58-
for (size_t i = 0; i < out_var_names.size(); i++) {
59-
auto *out_var = scope.FindVar(out_var_names[i]);
60-
61-
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
62-
out_var_names[i]);
63-
64-
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
65-
66-
// Error checking
67-
PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
68-
69-
// Get data from fin to tensor
70-
DeserializeFromStream(*buffer, tensor, dev_ctx);
71-
72-
auto in_dtype = tensor->type();
73-
auto out_dtype =
74-
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
75-
76-
if (in_dtype != out_dtype) {
77-
// convert to float16 tensor
78-
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
79-
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
80-
framework::LoDTensor fp16_tensor;
81-
// copy LoD info to the new tensor
82-
fp16_tensor.set_lod(tensor->lod());
83-
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
84-
&fp16_tensor);
85-
86-
// reset output tensor
87-
out_var->Clear();
88-
tensor = out_var->GetMutable<framework::LoDTensor>();
89-
tensor->set_lod(fp16_tensor.lod());
90-
tensor->ShareDataWith(fp16_tensor);
91-
}
92-
}
93-
buffer->peek();
94-
PADDLE_ENFORCE(buffer->eof(),
95-
"You are not allowed to load partial data via "
96-
"load_combine_op, use load_op instead.");
25+
using framework::OperatorWithKernel::OperatorWithKernel;
26+
27+
void InferShape(framework::InferShapeContext *ctx) const override {}
28+
29+
protected:
30+
framework::OpKernelType GetExpectedKernelType(
31+
const framework::ExecutionContext &ctx) const override {
32+
framework::OpKernelType kt = framework::OpKernelType(
33+
framework::proto::VarType::FP32, ctx.GetPlace());
34+
return kt;
9735
}
9836
};
9937

@@ -124,21 +62,30 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
12462
AddComment(R"DOC(
12563
LoadCombine Operator.
12664
127-
LoadCombine operator loads LoDTensor variables from a file, which could be
128-
loaded in memory already. The file should contain one or more LoDTensors
65+
LoadCombine operator loads LoDTensor variables from a file, which could be
66+
loaded in memory already. The file should contain one or more LoDTensors
12967
serialized using the SaveCombine operator. The
130-
LoadCombine operator applies a deserialization strategy to appropriately load
131-
the LodTensors, and this strategy complements the serialization strategy used
68+
LoadCombine operator applies a deserialization strategy to appropriately load
69+
the LoDTensors, and this strategy complements the serialization strategy used
13270
in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled
133-
with the SaveCombine operator, and can only deserialize one or more LoDTensors
71+
with the SaveCombine operator, and can only deserialize one or more LoDTensors
13472
that were saved using the SaveCombine operator.
13573
13674
)DOC");
13775
}
13876
};
77+
13978
} // namespace operators
14079
} // namespace paddle
80+
14181
namespace ops = paddle::operators;
14282

14383
REGISTER_OPERATOR(load_combine, ops::LoadCombineOp,
14484
ops::LoadCombineOpProtoMaker);
85+
86+
REGISTER_OP_CPU_KERNEL(
87+
load_combine,
88+
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
89+
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
90+
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
91+
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/load_combine_op.h"

namespace ops = paddle::operators;

// CUDA-side registrations for load_combine. The kernel implementation is
// device-agnostic and shared with the CPU build via load_combine_op.h.
REGISTER_OP_CUDA_KERNEL(
    load_combine,
    ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int>,
    ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
    ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <fstream>
#include <sstream>  // std::stringstream was only available transitively before
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

// Kernel shared by the CPU and CUDA registrations of load_combine.
// Deserializes one or more LoDTensors that were written by save_combine,
// either from a file on disk or from an in-memory buffer carried in the
// "file_path" attribute, and stores them into the op's "Out" variables
// in order.
template <typename DeviceContext, typename T>
class LoadCombineOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto place = ctx.GetPlace();
    auto filename = ctx.Attr<std::string>("file_path");
    auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
    auto model_from_memory = ctx.Attr<bool>("model_from_memory");
    auto &out_var_names = ctx.Outputs("Out");

    PADDLE_ENFORCE_GT(
        static_cast<int>(out_var_names.size()), 0,
        "The number of output variables should be greater than 0.");
    if (!model_from_memory) {
      std::ifstream fin(filename, std::ios::binary);
      PADDLE_ENFORCE(static_cast<bool>(fin),
                     "Cannot open file %s for load_combine op", filename);
      LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
    } else {
      // In the in-memory case "file_path" holds the serialized bytes
      // themselves rather than a path.
      PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
      std::stringstream fin(filename, std::ios::in | std::ios::binary);
      LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
    }
  }

  // Reads one tensor per output variable from *buffer, optionally converting
  // each loaded tensor to FP16, and enforces that the stream is consumed
  // completely (a partial read indicates the wrong op was used).
  void LoadParamsFromBuffer(
      const framework::ExecutionContext &context, const platform::Place &place,
      std::istream *buffer, bool load_as_fp16,
      const std::vector<std::string> &out_var_names) const {
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);
    auto out_vars = context.MultiOutputVar("Out");

    for (size_t i = 0; i < out_var_names.size(); i++) {
      PADDLE_ENFORCE(out_vars[i] != nullptr,
                     "Output variable %s cannot be found", out_var_names[i]);

      auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();

      // The stream must still be readable before each tensor.
      PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");

      DeserializeFromStream(*buffer, tensor, dev_ctx);

      auto in_dtype = tensor->type();
      auto out_dtype =
          load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;

      if (in_dtype != out_dtype) {
        // Convert to a float16 tensor, carrying the LoD info over.
        auto in_kernel_type = framework::OpKernelType(in_dtype, place);
        auto out_kernel_type = framework::OpKernelType(out_dtype, place);
        framework::LoDTensor fp16_tensor;
        fp16_tensor.set_lod(tensor->lod());
        framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
                                 &fp16_tensor);

        // Swap the output variable's storage for the converted data.
        out_vars[i]->Clear();
        tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
        tensor->set_lod(fp16_tensor.lod());
        tensor->ShareDataWith(fp16_tensor);
      }
    }
    // peek() sets eofbit when everything was consumed; leftover bytes mean
    // the file contains more tensors than this op has outputs.
    buffer->peek();
    PADDLE_ENFORCE(buffer->eof(),
                   "You are not allowed to load partial data via "
                   "load_combine_op, use load_op instead.");
  }
};

}  // namespace operators
}  // namespace paddle

paddle/fluid/operators/load_op.cc

Lines changed: 19 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -11,89 +11,26 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14-
#include <fstream>
1514

16-
#include "paddle/fluid/framework/data_type_transform.h"
17-
#include "paddle/fluid/framework/op_registry.h"
18-
#include "paddle/fluid/platform/device_context.h"
19-
#include "paddle/fluid/platform/profiler.h"
15+
#include <string>
16+
17+
#include "paddle/fluid/operators/load_op.h"
2018

2119
namespace paddle {
2220
namespace operators {
2321

24-
class LoadOp : public framework::OperatorBase {
22+
class LoadOp : public framework::OperatorWithKernel {
2523
public:
26-
LoadOp(const std::string &type, const framework::VariableNameMap &inputs,
27-
const framework::VariableNameMap &outputs,
28-
const framework::AttributeMap &attrs)
29-
: OperatorBase(type, inputs, outputs, attrs) {}
30-
31-
private:
32-
void RunImpl(const framework::Scope &scope,
33-
const platform::Place &place) const override {
34-
// FIXME(yuyang18): We save variable to local file now, but we should change
35-
// it to save an output stream.
36-
auto filename = Attr<std::string>("file_path");
37-
std::ifstream fin(filename, std::ios::binary);
38-
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
39-
filename);
24+
using framework::OperatorWithKernel::OperatorWithKernel;
4025

41-
auto out_var_name = Output("Out");
42-
auto *out_var = scope.FindVar(out_var_name);
43-
PADDLE_ENFORCE(out_var != nullptr,
44-
"Output variable %s cannot be found in scope %p",
45-
out_var_name, &scope);
26+
void InferShape(framework::InferShapeContext *ctx) const override {}
4627

47-
if (out_var->IsType<framework::LoDTensor>()) {
48-
LoadLodTensor(fin, place, out_var);
49-
} else if (out_var->IsType<framework::SelectedRows>()) {
50-
LoadSelectedRows(fin, place, out_var);
51-
} else {
52-
PADDLE_ENFORCE(
53-
false,
54-
"Load only support LoDTensor and SelectedRows, %s has wrong type",
55-
out_var_name);
56-
}
57-
}
58-
59-
void LoadLodTensor(std::istream &fin, const platform::Place &place,
60-
framework::Variable *var) const {
61-
// get device context from pool
62-
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
63-
auto &dev_ctx = *pool.Get(place);
64-
auto *tensor = var->GetMutable<framework::LoDTensor>();
65-
DeserializeFromStream(fin, tensor, dev_ctx);
66-
67-
auto load_as_fp16 = Attr<bool>("load_as_fp16");
68-
auto in_dtype = tensor->type();
69-
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
70-
71-
if (in_dtype != out_dtype) {
72-
// convert to float16 tensor
73-
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
74-
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
75-
framework::LoDTensor fp16_tensor;
76-
// copy LoD info to the new tensor
77-
fp16_tensor.set_lod(tensor->lod());
78-
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
79-
&fp16_tensor);
80-
81-
// reset output tensor
82-
var->Clear();
83-
tensor = var->GetMutable<framework::LoDTensor>();
84-
tensor->set_lod(fp16_tensor.lod());
85-
tensor->ShareDataWith(fp16_tensor);
86-
}
87-
}
88-
89-
void LoadSelectedRows(std::istream &fin, const platform::Place &place,
90-
framework::Variable *var) const {
91-
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
92-
// get device context from pool
93-
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
94-
auto &dev_ctx = *pool.Get(place);
95-
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
96-
selectedRows->SyncIndex();
28+
protected:
29+
framework::OpKernelType GetExpectedKernelType(
30+
const framework::ExecutionContext &ctx) const override {
31+
framework::OpKernelType kt = framework::OpKernelType(
32+
framework::proto::VarType::FP32, platform::CPUPlace());
33+
return kt;
9734
}
9835
};
9936

@@ -116,8 +53,15 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
11653
"file.");
11754
}
11855
};
56+
11957
} // namespace operators
12058
} // namespace paddle
12159
namespace ops = paddle::operators;
12260

12361
REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
62+
63+
REGISTER_OP_CPU_KERNEL(
64+
load, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, float>,
65+
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, double>,
66+
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int>,
67+
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);

0 commit comments

Comments (0)