@@ -11,89 +11,27 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#include <fstream>
-#include "paddle/fluid/framework/data_type_transform.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/device_context.h"
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/operators/load_combine_op.h"

namespace paddle {
namespace operators {

-class LoadCombineOp : public framework::OperatorBase {
+class LoadCombineOp : public framework::OperatorWithKernel {
 public:
-  LoadCombineOp(const std::string &type,
-                const framework::VariableNameMap &inputs,
-                const framework::VariableNameMap &outputs,
-                const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto filename = Attr<std::string>("file_path");
-    auto load_as_fp16 = Attr<bool>("load_as_fp16");
-    auto model_from_memory = Attr<bool>("model_from_memory");
-    auto out_var_names = Outputs("Out");
-    PADDLE_ENFORCE_GT(
-        static_cast<int>(out_var_names.size()), 0,
-        "The number of output variables should be greater than 0.");
-    if (!model_from_memory) {
-      std::ifstream fin(filename, std::ios::binary);
-      PADDLE_ENFORCE(static_cast<bool>(fin),
-                     "Cannot open file %s for load_combine op", filename);
-      LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
-    } else {
-      PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
-      std::stringstream fin(filename, std::ios::in | std::ios::binary);
-      LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
-    }
-  }
-  void LoadParamsFromBuffer(
-      const framework::Scope &scope, const platform::Place &place,
-      std::istream *buffer, bool load_as_fp16,
-      const std::vector<std::string> &out_var_names) const {
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-
-    for (size_t i = 0; i < out_var_names.size(); i++) {
-      auto *out_var = scope.FindVar(out_var_names[i]);
-
-      PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
-                     out_var_names[i]);
-
-      auto *tensor = out_var->GetMutable<framework::LoDTensor>();
-
-      // Error checking
-      PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
-
-      // Get data from fin to tensor
-      DeserializeFromStream(*buffer, tensor, dev_ctx);
-
-      auto in_dtype = tensor->type();
-      auto out_dtype =
-          load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
-
-      if (in_dtype != out_dtype) {
-        // convert to float16 tensor
-        auto in_kernel_type = framework::OpKernelType(in_dtype, place);
-        auto out_kernel_type = framework::OpKernelType(out_dtype, place);
-        framework::LoDTensor fp16_tensor;
-        // copy LoD info to the new tensor
-        fp16_tensor.set_lod(tensor->lod());
-        framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
-                                 &fp16_tensor);
-
-        // reset output tensor
-        out_var->Clear();
-        tensor = out_var->GetMutable<framework::LoDTensor>();
-        tensor->set_lod(fp16_tensor.lod());
-        tensor->ShareDataWith(fp16_tensor);
-      }
-    }
-    buffer->peek();
-    PADDLE_ENFORCE(buffer->eof(),
-                   "You are not allowed to load partial data via "
-                   "load_combine_op, use load_op instead.");
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {}
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    framework::OpKernelType kt = framework::OpKernelType(
+        framework::proto::VarType::FP32, ctx.GetPlace());
+    return kt;
  }
};

@@ -124,21 +62,30 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
    AddComment(R"DOC(
LoadCombine Operator.

-LoadCombine operator loads LoDTensor variables from a file, which could be
-loaded in memory already. The file should contain one or more LoDTensors
+LoadCombine operator loads LoDTensor variables from a file, which could be
+loaded in memory already. The file should contain one or more LoDTensors
serialized using the SaveCombine operator. The
-LoadCombine operator applies a deserialization strategy to appropriately load
-the LodTensors, and this strategy complements the serialization strategy used
+LoadCombine operator applies a deserialization strategy to appropriately load
+the LodTensors, and this strategy complements the serialization strategy used
in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled
-with the SaveCombine operator, and can only deserialize one or more LoDTensors
+with the SaveCombine operator, and can only deserialize one or more LoDTensors
that were saved using the SaveCombine operator.

)DOC");
  }
};
+
}  // namespace operators
}  // namespace paddle
+
namespace ops = paddle::operators;

REGISTER_OPERATOR(load_combine, ops::LoadCombineOp,
                  ops::LoadCombineOpProtoMaker);
+
+REGISTER_OP_CPU_KERNEL(
+    load_combine,
+    ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
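The LoadCombineOpKernel referenced by the registration above is declared in the new paddle/fluid/operators/load_combine_op.h, which this diff does not show. As a rough sketch of what that header could contain, the removed RunImpl()/LoadParamsFromBuffer() logic can be folded into an OpKernel::Compute() along the lines below; the ExecutionContext accessors used here (Attr, MultiOutputVar, GetPlace) and the exact structure are assumptions, not the header's actual contents.

// Hypothetical sketch of load_combine_op.h, reconstructed from the removed
// OperatorBase implementation; the real header may differ.
#pragma once

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class LoadCombineOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto place = ctx.GetPlace();
    auto filename = ctx.Attr<std::string>("file_path");
    auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
    auto model_from_memory = ctx.Attr<bool>("model_from_memory");
    auto out_vars = ctx.MultiOutputVar("Out");  // assumed accessor
    PADDLE_ENFORCE_GT(
        static_cast<int>(out_vars.size()), 0,
        "The number of output variables should be greater than 0.");
    if (!model_from_memory) {
      std::ifstream fin(filename, std::ios::binary);
      PADDLE_ENFORCE(static_cast<bool>(fin),
                     "Cannot open file %s for load_combine op", filename);
      LoadParamsFromBuffer(place, &fin, load_as_fp16, out_vars);
    } else {
      // With model_from_memory, file_path carries the serialized bytes.
      PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
      std::stringstream fin(filename, std::ios::in | std::ios::binary);
      LoadParamsFromBuffer(place, &fin, load_as_fp16, out_vars);
    }
  }

 private:
  void LoadParamsFromBuffer(
      const platform::Place &place, std::istream *buffer, bool load_as_fp16,
      const std::vector<framework::Variable *> &out_vars) const {
    auto &dev_ctx = *platform::DeviceContextPool::Instance().Get(place);
    for (size_t i = 0; i < out_vars.size(); i++) {
      PADDLE_ENFORCE(out_vars[i] != nullptr,
                     "Output variable %d cannot be found",
                     static_cast<int>(i));
      auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
      PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
      // Deserialize the next LoDTensor from the combined stream.
      DeserializeFromStream(*buffer, tensor, dev_ctx);
      auto in_dtype = tensor->type();
      auto out_dtype =
          load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
      if (in_dtype != out_dtype) {
        // Convert the loaded tensor to float16, preserving its LoD.
        auto in_kernel_type = framework::OpKernelType(in_dtype, place);
        auto out_kernel_type = framework::OpKernelType(out_dtype, place);
        framework::LoDTensor fp16_tensor;
        fp16_tensor.set_lod(tensor->lod());
        framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
                                 &fp16_tensor);
        out_vars[i]->Clear();
        tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
        tensor->set_lod(fp16_tensor.lod());
        tensor->ShareDataWith(fp16_tensor);
      }
    }
    buffer->peek();
    PADDLE_ENFORCE(buffer->eof(),
                   "You are not allowed to load partial data via "
                   "load_combine_op, use load_op instead.");
  }
};

}  // namespace operators
}  // namespace paddle

Switching the op to OperatorWithKernel and registering per-dtype CPU kernels is what allows the framework to dispatch by place and data type here, instead of the old OperatorBase path that ran directly against the scope.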