
Commit cc75e84

Merge pull request #10541 from kexinzhao/load_fp16
Add float16 support to load op
2 parents: 28de0ea + aa2635f

2 files changed: +79 -1 lines changed

paddle/fluid/operators/load_op.cc

Lines changed: 29 additions & 0 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <fstream>
 
+#include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -46,6 +47,27 @@ class LoadOp : public framework::OperatorBase {
     auto *tensor = out_var->GetMutable<framework::LoDTensor>();
 
     DeserializeFromStream(fin, tensor, *dev_ctx);
+
+    auto load_as_fp16 = Attr<bool>("load_as_fp16");
+    auto in_dtype = framework::ToDataType(tensor->type());
+    auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
+
+    if (in_dtype != out_dtype) {
+      // convert to float16 tensor
+      auto in_kernel_type = framework::OpKernelType(in_dtype, place);
+      auto out_kernel_type = framework::OpKernelType(out_dtype, place);
+      framework::LoDTensor fp16_tensor;
+      // copy LoD info to the new tensor
+      fp16_tensor.set_lod(tensor->lod());
+      framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
+                               &fp16_tensor);
+
+      // reset output tensor
+      out_var->Clear();
+      tensor = out_var->GetMutable<framework::LoDTensor>();
+      tensor->set_lod(fp16_tensor.lod());
+      tensor->ShareDataWith(fp16_tensor);
+    }
   }
 };
 
@@ -54,6 +76,13 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
   LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
     AddOutput("Out", "(Tensor) The tensor need to be loaded");
+    AddAttr<bool>(
+        "load_as_fp16",
+        "(boolean, default false)"
+        "If true, the tensor will be first loaded and then "
+        "converted to float16 data type. Otherwise, the tensor will be "
+        "directly loaded without data type conversion.")
+        .SetDefault(false);
     AddAttr<std::string>("file_path",
                          "(string) "
                          "Variable will be loaded from \"file_path\".")

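With this change, a caller can ask the load op to convert the loaded tensor to float16 by setting the new load_as_fp16 attribute. Below is a minimal usage sketch that mirrors the new LoadFP16Op test in the next file; it assumes a float tensor was previously written to "tensor.save" by the save op, and "out_var" is just a placeholder variable name.

// Sketch only: assumes the usual Paddle Fluid headers (scope.h, op_registry.h,
// place.h) and a tensor already saved to "tensor.save" (placeholder path).
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;

paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", std::string("tensor.save")});
attrs.insert({"load_as_fp16", true});  // convert to float16 after loading

scope.Var("out_var");  // destination variable for the loaded tensor
auto load_op = paddle::framework::OpRegistry::CreateOp(
    "load", {}, {{"Out", {"out_var"}}}, attrs);
load_op->Run(scope, place);
// "out_var" now holds a float16 LoDTensor with the original LoD preserved.
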
paddle/fluid/operators/save_load_op_test.cc

Lines changed: 50 additions & 1 deletion
@@ -63,7 +63,7 @@ TEST(SaveLoadOp, CPU) {
   }
 }
 
-TEST(SaveLoadFP16Op, CPU) {
+TEST(SaveFP16Op, CPU) {
   paddle::framework::Scope scope;
   paddle::platform::CPUPlace place;
 
@@ -94,3 +94,52 @@ TEST(SaveLoadFP16Op, CPU) {
     EXPECT_EQ(expect[i], static_cast<float>(actual[i]));
   }
 }
+
+TEST(LoadFP16Op, CPU) {
+  paddle::framework::Scope scope;
+  paddle::platform::CPUPlace place;
+
+  auto var = scope.Var("test_var");
+  auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
+  tensor->Resize({3, 10});
+
+  paddle::framework::LoD expect_lod;
+  expect_lod.resize(1);
+  expect_lod[0].push_back(0);
+  expect_lod[0].push_back(1);
+  expect_lod[0].push_back(2);
+  expect_lod[0].push_back(3);
+
+  tensor->set_lod(expect_lod);
+  float* expect = tensor->mutable_data<float>(place);
+  for (int64_t i = 0; i < tensor->numel(); ++i) {
+    expect[i] = static_cast<float>(paddle::platform::float16(i));
+  }
+
+  paddle::framework::AttributeMap attrs;
+  attrs.insert({"file_path", std::string("tensor.save")});
+  attrs.insert({"load_as_fp16", true});
+
+  auto save_op = paddle::framework::OpRegistry::CreateOp(
+      "save", {{"X", {"test_var"}}}, {}, attrs);
+  save_op->Run(scope, place);
+
+  auto load_var = scope.Var("out_var");
+  auto load_op = paddle::framework::OpRegistry::CreateOp(
+      "load", {}, {{"Out", {"out_var"}}}, attrs);
+  load_op->Run(scope, place);
+
+  auto target = load_var->Get<paddle::framework::LoDTensor>();
+  paddle::platform::float16* actual = target.data<paddle::platform::float16>();
+  for (int64_t i = 0; i < tensor->numel(); ++i) {
+    EXPECT_EQ(expect[i], static_cast<float>(actual[i]));
+  }
+
+  auto& actual_lod = target.lod();
+  EXPECT_EQ(expect_lod.size(), actual_lod.size());
+  for (size_t i = 0; i < expect_lod.size(); ++i) {
+    for (size_t j = 0; j < expect_lod[i].size(); ++j) {
+      EXPECT_EQ(expect_lod[i][j], actual_lod[i][j]);
+    }
+  }
+}
