Skip to content

Commit 28a6037

Browse files
authored
Fix lod check in FP16 test for save_op (#10508)
1 parent ce72c3f commit 28a6037

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

paddle/fluid/operators/save_load_op_test.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,14 @@ TEST(SaveFP16Op, CPU) {
7070
auto var = scope.Var("test_var");
7171
auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
7272
tensor->Resize({3, 10});
73+
paddle::framework::LoD expect_lod;
74+
expect_lod.resize(1);
75+
expect_lod[0].push_back(0);
76+
expect_lod[0].push_back(1);
77+
expect_lod[0].push_back(2);
78+
expect_lod[0].push_back(3);
7379

80+
tensor->set_lod(expect_lod);
7481
float* expect = tensor->mutable_data<float>(place);
7582
for (int64_t i = 0; i < tensor->numel(); ++i) {
7683
expect[i] = static_cast<float>(paddle::platform::float16(i));
@@ -93,6 +100,13 @@ TEST(SaveFP16Op, CPU) {
93100
for (int64_t i = 0; i < tensor->numel(); ++i) {
94101
EXPECT_EQ(expect[i], static_cast<float>(actual[i]));
95102
}
103+
auto& actual_lod = target->lod();
104+
EXPECT_EQ(expect_lod.size(), actual_lod.size());
105+
for (size_t i = 0; i < expect_lod.size(); ++i) {
106+
for (size_t j = 0; j < expect_lod[i].size(); ++j) {
107+
EXPECT_EQ(expect_lod[i][j], actual_lod[i][j]);
108+
}
109+
}
96110
}
97111

98112
TEST(LoadFP16Op, CPU) {

0 commit comments

Comments
 (0)