@@ -18,6 +18,7 @@ limitations under the License. */
18
18
#include < numeric>
19
19
20
20
#include " paddle/fluid/framework/data_type.h"
21
+ #include " paddle/fluid/framework/data_type_transform.h"
21
22
#include " paddle/fluid/framework/framework.pb.h"
22
23
#include " paddle/fluid/framework/lod_tensor.h"
23
24
#include " paddle/fluid/framework/op_registry.h"
@@ -68,6 +69,7 @@ class SaveOp : public framework::OperatorBase {
68
69
const platform::Place &place) const override {
69
70
auto filename = Attr<std::string>(" file_path" );
70
71
auto overwrite = Attr<bool >(" overwrite" );
72
+ auto save_as_fp16 = Attr<bool >(" save_as_fp16" );
71
73
72
74
if (FileExists (filename) && !overwrite) {
73
75
PADDLE_THROW (" %s is existed, cannot save to it when overwrite=false" ,
@@ -96,7 +98,18 @@ class SaveOp : public framework::OperatorBase {
96
98
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
97
99
auto &dev_ctx = *pool.Get (place);
98
100
99
- framework::SerializeToStream (fout, tensor, dev_ctx);
101
+ auto in_dtype = framework::ToDataType (tensor.type ());
102
+ auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
103
+
104
+ if (in_dtype != out_dtype) {
105
+ auto in_kernel_type = framework::OpKernelType (in_dtype, place);
106
+ auto out_kernel_type = framework::OpKernelType (out_dtype, place);
107
+ framework::LoDTensor out;
108
+ framework::TransDataType (in_kernel_type, out_kernel_type, tensor, &out);
109
+ framework::SerializeToStream (fout, out, dev_ctx);
110
+ } else {
111
+ framework::SerializeToStream (fout, tensor, dev_ctx);
112
+ }
100
113
}
101
114
};
102
115
@@ -114,6 +127,12 @@ This operator will serialize and write a tensor variable to file on disk.
114
127
" (boolean, default true)"
115
128
" Overwrite the output file if exist" )
116
129
.SetDefault (true );
130
+ AddAttr<bool >(" save_as_fp16" ,
131
+ " (boolean, default false)"
132
+ " If true, the tensor will be converted to float16 data "
133
+ " type and then saved. Otherwise, the tensor will be "
134
+ " directly saved without data type conversion." )
135
+ .SetDefault (false );
117
136
AddAttr<std::string>(" file_path" ,
118
137
" (string)"
119
138
" The \" file_path\" where the variable will be saved." )
0 commit comments