
Commit 622fe6a

checkpoint pr be moved here, test=develop
1 parent bed0ecf commit 622fe6a

File tree

12 files changed: +543 additions, -343 deletions


paddle/fluid/operators/load_combine_op.cc

Lines changed: 6 additions & 79 deletions
@@ -11,11 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <fstream>
-#include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/data_type_transform.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/device_context.h"
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/operators/load_combine_op.h"
 
 namespace paddle {
 namespace operators {
@@ -30,7 +30,7 @@ class LoadCombineOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     framework::OpKernelType kt = framework::OpKernelType(
-        framework::proto::VarType::FP32, platform::CPUPlace());
+        framework::proto::VarType::FP32, ctx.GetPlace());
     return kt;
   }
 };
@@ -75,79 +75,6 @@ that were saved using the SaveCombine operator.
   }
 };
 
-template <typename DeviceContext, typename T>
-class LoadCombineOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto place = ctx.GetPlace();
-    auto filename = ctx.Attr<std::string>("file_path");
-    auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
-    auto model_from_memory = ctx.Attr<bool>("model_from_memory");
-    auto &out_var_names = ctx.Outputs("Out");
-
-    PADDLE_ENFORCE_GT(
-        static_cast<int>(out_var_names.size()), 0,
-        "The number of output variables should be greater than 0.");
-    if (!model_from_memory) {
-      std::ifstream fin(filename, std::ios::binary);
-      PADDLE_ENFORCE(static_cast<bool>(fin),
-                     "Cannot open file %s for load_combine op", filename);
-      LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
-    } else {
-      PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
-      std::stringstream fin(filename, std::ios::in | std::ios::binary);
-      LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
-    }
-  }
-
-  void LoadParamsFromBuffer(
-      const framework::ExecutionContext &context, const platform::Place &place,
-      std::istream *buffer, bool load_as_fp16,
-      const std::vector<std::string> &out_var_names) const {
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-    auto out_vars = context.MultiOutputVar("Out");
-
-    for (size_t i = 0; i < out_var_names.size(); i++) {
-      PADDLE_ENFORCE(out_vars[i] != nullptr,
-                     "Output variable %s cannot be found", out_var_names[i]);
-
-      auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
-
-      // Error checking
-      PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
-
-      // Get data from fin to tensor
-      DeserializeFromStream(*buffer, tensor, dev_ctx);
-
-      auto in_dtype = tensor->type();
-      auto out_dtype =
-          load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
-
-      if (in_dtype != out_dtype) {
-        // convert to float16 tensor
-        auto in_kernel_type = framework::OpKernelType(in_dtype, place);
-        auto out_kernel_type = framework::OpKernelType(out_dtype, place);
-        framework::LoDTensor fp16_tensor;
-        // copy LoD info to the new tensor
-        fp16_tensor.set_lod(tensor->lod());
-        framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
-                                 &fp16_tensor);
-
-        // reset output tensor
-        out_vars[i]->Clear();
-        tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
-        tensor->set_lod(fp16_tensor.lod());
-        tensor->ShareDataWith(fp16_tensor);
-      }
-    }
-    buffer->peek();
-    PADDLE_ENFORCE(buffer->eof(),
-                   "You are not allowed to load partial data via "
-                   "load_combine_op, use load_op instead.");
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
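Note on this file: GetExpectedKernelType now derives the kernel key from ctx.GetPlace() instead of a hard-coded CPUPlace, so the operator can select a GPU kernel when it runs on a CUDA place, and the templated LoadCombineOpKernel body moves into load_combine_op.h so that both the CPU translation unit and the new load_combine_op.cu can instantiate it. The same split is applied to load_op further down (its kernel presumably lands in load_op.h, which is not part of the excerpt shown here). The CPU-side registration sits outside the hunks above; a hedged sketch of what presumably remains at the bottom of load_combine_op.cc, mirroring the new CUDA registration (the exact dtype list is an assumption, not taken from this diff):

// Hedged sketch: the CPU registration presumed to remain in load_combine_op.cc,
// outside the hunks shown above. The dtype list is an assumption.
#include "paddle/fluid/operators/load_combine_op.h"

namespace ops = paddle::operators;

REGISTER_OP_CPU_KERNEL(
    load_combine,
    ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
    ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);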

paddle/fluid/operators/load_combine_op.cu (new file)

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/load_combine_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    load_combine,
+    ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
paddle/fluid/operators/load_combine_op.h (new file)

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/data_type_transform.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+template <typename DeviceContext, typename T>
+class LoadCombineOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto place = ctx.GetPlace();
+    auto filename = ctx.Attr<std::string>("file_path");
+    auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
+    auto model_from_memory = ctx.Attr<bool>("model_from_memory");
+    auto &out_var_names = ctx.Outputs("Out");
+
+    PADDLE_ENFORCE_GT(
+        static_cast<int>(out_var_names.size()), 0,
+        "The number of output variables should be greater than 0.");
+    if (!model_from_memory) {
+      std::ifstream fin(filename, std::ios::binary);
+      PADDLE_ENFORCE(static_cast<bool>(fin),
+                     "Cannot open file %s for load_combine op", filename);
+      LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
+    } else {
+      PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
+      std::stringstream fin(filename, std::ios::in | std::ios::binary);
+      LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
+    }
+  }
+
+  void LoadParamsFromBuffer(
+      const framework::ExecutionContext &context, const platform::Place &place,
+      std::istream *buffer, bool load_as_fp16,
+      const std::vector<std::string> &out_var_names) const {
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(place);
+    auto out_vars = context.MultiOutputVar("Out");
+
+    for (size_t i = 0; i < out_var_names.size(); i++) {
+      PADDLE_ENFORCE(out_vars[i] != nullptr,
+                     "Output variable %s cannot be found", out_var_names[i]);
+
+      auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
+
+      // Error checking
+      PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
+
+      // Get data from fin to tensor
+      DeserializeFromStream(*buffer, tensor, dev_ctx);
+
+      auto in_dtype = tensor->type();
+      auto out_dtype =
+          load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
+
+      if (in_dtype != out_dtype) {
+        // convert to float16 tensor
+        auto in_kernel_type = framework::OpKernelType(in_dtype, place);
+        auto out_kernel_type = framework::OpKernelType(out_dtype, place);
+        framework::LoDTensor fp16_tensor;
+        // copy LoD info to the new tensor
+        fp16_tensor.set_lod(tensor->lod());
+        framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
+                                 &fp16_tensor);
+
+        // reset output tensor
+        out_vars[i]->Clear();
+        tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
+        tensor->set_lod(fp16_tensor.lod());
+        tensor->ShareDataWith(fp16_tensor);
+      }
+    }
+    buffer->peek();
+    PADDLE_ENFORCE(buffer->eof(),
+                   "You are not allowed to load partial data via "
+                   "load_combine_op, use load_op instead.");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
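The header spells out the combined-file contract: the file (or the in-memory buffer passed through file_path when model_from_memory is set) holds the output tensors serialized back-to-back, the kernel deserializes them into the "Out" variables in order, optionally casting to FP16, and the final peek()/eof() check rejects buffers that still hold unread data. Below is a minimal standalone sketch of that stream layout using the framework's SerializeToStream/DeserializeFromStream helpers on CPU; the file name and tensor shapes are made up for illustration.

// Sketch only: writes two LoDTensors back-to-back into one file, then reads
// them back in order, which is the layout LoadParamsFromBuffer consumes.
#include <fstream>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device_context.h"

int main() {
  namespace fw = paddle::framework;
  paddle::platform::CPUPlace place;
  paddle::platform::CPUDeviceContext dev_ctx(place);

  // Two small parameter tensors with arbitrary contents.
  std::vector<fw::LoDTensor> params(2);
  for (auto &t : params) {
    t.Resize({2, 3});
    float *data = t.mutable_data<float>(place);
    for (int j = 0; j < t.numel(); ++j) data[j] = 1.0f;
  }

  // Write them into a single "combined" file (hypothetical file name).
  std::ofstream fout("params.bin", std::ios::binary);
  for (auto &t : params) fw::SerializeToStream(fout, t, dev_ctx);
  fout.close();

  // Read them back in the same order.
  std::ifstream fin("params.bin", std::ios::binary);
  for (auto &t : params) fw::DeserializeFromStream(fin, &t, dev_ctx);
  return 0;
}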

paddle/fluid/operators/load_op.cc

Lines changed: 3 additions & 79 deletions
@@ -11,12 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <fstream>
 
-#include "paddle/fluid/framework/data_type_transform.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/profiler.h"
+#include <string>
+
+#include "paddle/fluid/operators/load_op.h"
 
 namespace paddle {
 namespace operators {
@@ -56,80 +54,6 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-template <typename DeviceContext, typename T>
-class LoadOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto place = ctx.GetPlace();
-    // FIXME(yuyang18): We save variable to local file now, but we should change
-    // it to save an output stream.
-    auto filename = ctx.Attr<std::string>("file_path");
-    std::ifstream fin(filename, std::ios::binary);
-    PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
-                   filename);
-
-    auto out_var_name = ctx.Outputs("Out").data();
-    auto *out_var = ctx.OutputVar("Out");
-
-    PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found ",
-                   out_var_name);
-
-    PADDLE_ENFORCE(out_var != nullptr, "Output variable cannot be found ");
-
-    if (out_var->IsType<framework::LoDTensor>()) {
-      LoadLodTensor(fin, place, out_var, ctx);
-    } else if (out_var->IsType<framework::SelectedRows>()) {
-      LoadSelectedRows(fin, place, out_var);
-    } else {
-      PADDLE_ENFORCE(
-          false,
-          "Load only support LoDTensor and SelectedRows, %s has wrong type",
-          out_var_name);
-    }
-  }
-
-  void LoadLodTensor(std::istream &fin, const platform::Place &place,
-                     framework::Variable *var,
-                     const framework::ExecutionContext &ctx) const {
-    // get device context from pool
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-    auto *tensor = var->GetMutable<framework::LoDTensor>();
-    DeserializeFromStream(fin, tensor, dev_ctx);
-
-    auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
-    auto in_dtype = tensor->type();
-    auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
-
-    if (in_dtype != out_dtype) {
-      // convert to float16 tensor
-      auto in_kernel_type = framework::OpKernelType(in_dtype, place);
-      auto out_kernel_type = framework::OpKernelType(out_dtype, place);
-      framework::LoDTensor fp16_tensor;
-      // copy LoD info to the new tensor
-      fp16_tensor.set_lod(tensor->lod());
-      framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
-                               &fp16_tensor);
-
-      // reset output tensor
-      var->Clear();
-      tensor = var->GetMutable<framework::LoDTensor>();
-      tensor->set_lod(fp16_tensor.lod());
-      tensor->ShareDataWith(fp16_tensor);
-    }
-  }
-
-  void LoadSelectedRows(std::istream &fin, const platform::Place &place,
-                        framework::Variable *var) const {
-    auto *selectedRows = var->GetMutable<framework::SelectedRows>();
-    // get device context from pool
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-    framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
-    selectedRows->SyncIndex();
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;

paddle/fluid/operators/load_op.cu

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/load_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    load, ops::LoadOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LoadOpKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::LoadOpKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::LoadOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
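With CUDA kernels registered for both operators, load and load_combine can run on the device where the parameters will live. Below is a hedged sketch of driving load_combine directly through the operator registry, in the style of the existing save/load op unit tests; "params.bin", "w0", and "w1" are placeholder names, and the file is expected to contain two serialized LoDTensors (for example, written as in the earlier sketch).

// Hedged sketch, not from this diff: construct and run a load_combine op.
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"

USE_OP(load_combine);

int main() {
  namespace fw = paddle::framework;

  // Output variables the kernel will fill with the deserialized tensors.
  fw::Scope scope;
  scope.Var("w0");
  scope.Var("w1");

  // Only file_path is set here; the other attributes keep their defaults.
  fw::AttributeMap attrs;
  attrs.insert({"file_path", std::string("params.bin")});

  auto op = fw::OpRegistry::CreateOp("load_combine", {},
                                     {{"Out", {"w0", "w1"}}}, attrs);

  paddle::platform::CPUPlace place;  // or paddle::platform::CUDAPlace(0)
  op->Run(scope, place);
  return 0;
}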
