Skip to content

Commit 4313d87

Browse files
committed
refine
1 parent c69cf6d commit 4313d87

File tree

6 files changed

+37
-11
lines changed

6 files changed

+37
-11
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ else()
5656
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
5757
endif()
5858
if (NOT WIN32)
59-
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
59+
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version)
6060
else()
61-
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
61+
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
6262
endif (NOT WIN32)
6363

6464
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)

paddle/fluid/framework/lod_tensor.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include "paddle/fluid/framework/framework.pb.h"
2222
#include "paddle/fluid/framework/lod_tensor.h"
2323
#include "paddle/fluid/framework/var_type.h"
24+
#include "paddle/fluid/framework/version.h"
2425

2526
#include "paddle/fluid/memory/memcpy.h"
2627
#include "paddle/fluid/memory/memory.h"
@@ -251,8 +252,8 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
251252
void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
252253
const platform::DeviceContext &dev_ctx) {
253254
{ // the 1st field, uint32_t version for LoDTensor
254-
constexpr uint32_t version = 0;
255-
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
255+
os.write(reinterpret_cast<const char *>(&kCurTensorVersion),
256+
sizeof(kCurTensorVersion));
256257
}
257258
{
258259
// the 2st field, LoD information
@@ -281,6 +282,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
281282
// the 1st field, unit32_t version for LoDTensor
282283
uint32_t version;
283284
is.read(reinterpret_cast<char *>(&version), sizeof(version));
285+
PADDLE_ENFORCE(framework::IsTensorVersionSupported(version),
286+
"tensor version %u is not supported.", version);
284287
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
285288
}
286289
{

paddle/fluid/framework/version.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,20 @@ limitations under the License. */
1717

1818
namespace paddle {
1919
namespace framework {
20-
bool IsProgramVersionSupported(int version) {
20+
bool IsProgramVersionSupported(int64_t version) {
2121
static int num_supported =
2222
sizeof(kSupportedProgramVersion) / sizeof(kSupportedProgramVersion[0]);
2323
return std::find(kSupportedProgramVersion,
2424
kSupportedProgramVersion + num_supported,
2525
version) != kSupportedProgramVersion + num_supported;
2626
}
27+
28+
bool IsTensorVersionSupported(uint32_t version) {
29+
static int num_supported =
30+
sizeof(kSupportedTensorVersion) / sizeof(kSupportedTensorVersion[0]);
31+
return std::find(kSupportedTensorVersion,
32+
kSupportedTensorVersion + num_supported,
33+
version) != kSupportedTensorVersion + num_supported;
34+
}
2735
} // namespace framework
2836
} // namespace paddle

paddle/fluid/framework/version.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include <cstdint>
16+
1517
#pragma once
1618

1719
namespace paddle {
1820
namespace framework {
1921

2022
// The program version the current codes generate.
21-
constexpr int kCurProgramVersion = 0;
23+
constexpr int64_t kCurProgramVersion = 0;
2224

2325
// The program version that was generated by previous or current codes
2426
// and supported by current codes.
25-
constexpr int kSupportedProgramVersion[] = {0};
27+
constexpr int64_t kSupportedProgramVersion[] = {0};
28+
29+
// Due to historical reasons, tensor version use uint32_t.
30+
constexpr uint32_t kCurTensorVersion = 0;
31+
32+
constexpr uint32_t kSupportedTensorVersion[] = {0};
33+
34+
bool IsProgramVersionSupported(int64_t version);
2635

27-
bool IsProgramVersionSupported(int version);
36+
bool IsTensorVersionSupported(uint32_t version);
2837

2938
} // namespace framework
3039
} // namespace paddle

paddle/fluid/framework/version_test.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717

1818
namespace paddle {
1919
namespace framework {
20-
TEST(Variable, GetMutable) {
20+
TEST(Version, Basic) {
2121
EXPECT_TRUE(IsProgramVersionSupported(0));
2222
EXPECT_FALSE(IsProgramVersionSupported(1));
2323
EXPECT_FALSE(IsProgramVersionSupported(-1));
24+
25+
EXPECT_TRUE(IsTensorVersionSupported(0));
26+
EXPECT_FALSE(IsTensorVersionSupported(1));
27+
EXPECT_FALSE(IsTensorVersionSupported(-1));
2428
}
2529
} // namespace framework
2630
} // namespace paddle

paddle/fluid/inference/io.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
126126
std::unique_ptr<framework::ProgramDesc> main_program(
127127
new framework::ProgramDesc(program_desc_str));
128128
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
129-
"model version %d is not supported.", main_program->Version());
129+
"model version %ld is not supported.",
130+
main_program->Version());
130131

131132
LoadPersistables(executor, scope, *main_program, dirname, "");
132133
return main_program;
@@ -142,7 +143,8 @@ std::unique_ptr<framework::ProgramDesc> Load(
142143
std::unique_ptr<framework::ProgramDesc> main_program(
143144
new framework::ProgramDesc(program_desc_str));
144145
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
145-
"model version %d is not supported.", main_program->Version());
146+
"model version %ld is not supported.",
147+
main_program->Version());
146148

147149
LoadPersistables(executor, scope, *main_program, "", param_filename);
148150
return main_program;

0 commit comments

Comments
 (0)