@@ -46,6 +46,19 @@ class LoadOp : public framework::OperatorBase {
46
46
auto *tensor = out_var->GetMutable <framework::LoDTensor>();
47
47
48
48
DeserializeFromStream (fin, tensor, *dev_ctx);
49
+
50
+ auto load_as_fp16 = Attr<bool >(" load_as_fp16" );
51
+ auto in_dtype = framework::ToDataType (tensor->type ());
52
+ auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
53
+
54
+ if (in_dtype != out_dtype) {
55
+ // convert to float16 tensor
56
+ auto in_kernel_type = framework::OpKernelType (in_dtype, place);
57
+ auto out_kernel_type = framework::OpKernelType (out_dtype, place);
58
+ framework::LoDTensor fp16_tensor;
59
+ framework::TransDataType (in_kernel_type, out_kernel_type, *tensor,
60
+ &fp16_tensor);
61
+ }
49
62
}
50
63
};
51
64
@@ -54,6 +67,13 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
54
67
LoadOpProtoMaker (OpProto *proto, OpAttrChecker *op_checker)
55
68
: OpProtoAndCheckerMaker(proto, op_checker) {
56
69
AddOutput (" Out" , " (Tensor) The tensor need to be loaded" );
70
+ AddAttr<bool >(
71
+ " load_as_fp16" ,
72
+ " (boolean, default false)"
73
+ " If true, the tensor will be first loaded and then "
74
+ " converted to float16 data type. Otherwise, the tensor will be "
75
+ " directly loaded without data type conversion." )
76
+ .SetDefault (false );
57
77
AddAttr<std::string>(" file_path" ,
58
78
" (string) "
59
79
" Variable will be loaded from \" file_path\" ." )
0 commit comments