Skip to content

Commit eb95417

Browse files
committed
initial commit
1 parent 0446220 commit eb95417

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

paddle/fluid/operators/load_op.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,19 @@ class LoadOp : public framework::OperatorBase {
4646
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
4747

4848
DeserializeFromStream(fin, tensor, *dev_ctx);
49+
50+
auto load_as_fp16 = Attr<bool>("load_as_fp16");
51+
auto in_dtype = framework::ToDataType(tensor->type());
52+
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
53+
54+
if (in_dtype != out_dtype) {
55+
// convert to float16 tensor
56+
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
57+
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
58+
framework::LoDTensor fp16_tensor;
59+
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
60+
&fp16_tensor);
61+
}
4962
}
5063
};
5164

@@ -54,6 +67,13 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
5467
LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
5568
: OpProtoAndCheckerMaker(proto, op_checker) {
5669
AddOutput("Out", "(Tensor) The tensor need to be loaded");
70+
AddAttr<bool>(
71+
"load_as_fp16",
72+
"(boolean, default false)"
73+
"If true, the tensor will be first loaded and then "
74+
"converted to float16 data type. Otherwise, the tensor will be "
75+
"directly loaded without data type conversion.")
76+
.SetDefault(false);
5777
AddAttr<std::string>("file_path",
5878
"(string) "
5979
"Variable will be loaded from \"file_path\".")

0 commit comments

Comments
 (0)