Skip to content

Commit 17bf871

Browse files
authored
Merge pull request #12988 from panyx0718/ir2
program and tensor versioning support
2 parents a557608 + e762d85 commit 17bf871

File tree

13 files changed

+164
-7
lines changed

13 files changed

+164
-7
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ else()
5656
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
5757
endif()
5858
if (NOT WIN32)
59-
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
59+
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version)
6060
else()
61-
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
61+
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
6262
endif (NOT WIN32)
6363

6464
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
@@ -116,7 +116,11 @@ cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope gl
116116
endif(NOT WIN32)
117117

118118
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
119-
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
119+
120+
cc_library(version SRCS version.cc)
121+
cc_test(version_test SRCS version_test.cc DEPS version)
122+
123+
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
120124

121125
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
122126
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)

paddle/fluid/framework/framework.proto

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ syntax = "proto2";
1616
option optimize_for = LITE_RUNTIME;
1717
package paddle.framework.proto;
1818

19+
// Any incompatible changes to ProgramDesc and its dependencies should
20+
// raise the version defined version.h.
21+
//
22+
// Serailization and Deserialization codes should be modified in a way
23+
// that supports old versions following the version and compatibility policy.
24+
message Version { optional int64 version = 1 [ default = 0 ]; }
25+
1926
enum AttrType {
2027
INT = 0;
2128
FLOAT = 1;
@@ -180,4 +187,8 @@ message BlockDesc {
180187
// for more details.
181188
// TODO(panyx0718): A model can have multiple programs. Need a
182189
// way to distinguish them. Maybe ID or name?
183-
message ProgramDesc { repeated BlockDesc blocks = 1; }
190+
message ProgramDesc {
191+
repeated BlockDesc blocks = 1;
192+
193+
optional Version version = 2;
194+
}

paddle/fluid/framework/lod_tensor.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include "paddle/fluid/framework/framework.pb.h"
2222
#include "paddle/fluid/framework/lod_tensor.h"
2323
#include "paddle/fluid/framework/var_type.h"
24+
#include "paddle/fluid/framework/version.h"
2425

2526
#include "paddle/fluid/memory/memcpy.h"
2627
#include "paddle/fluid/memory/memory.h"
@@ -251,8 +252,8 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
251252
void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
252253
const platform::DeviceContext &dev_ctx) {
253254
{ // the 1st field, uint32_t version for LoDTensor
254-
constexpr uint32_t version = 0;
255-
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
255+
os.write(reinterpret_cast<const char *>(&kCurTensorVersion),
256+
sizeof(kCurTensorVersion));
256257
}
257258
{
258259
// the 2st field, LoD information
@@ -281,6 +282,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
281282
// the 1st field, unit32_t version for LoDTensor
282283
uint32_t version;
283284
is.read(reinterpret_cast<char *>(&version), sizeof(version));
285+
PADDLE_ENFORCE(framework::IsTensorVersionSupported(version),
286+
"tensor version %u is not supported.", version);
284287
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
285288
}
286289
{

paddle/fluid/framework/program_desc.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/program_desc.h"
1616
#include "paddle/fluid/framework/block_desc.h"
1717
#include "paddle/fluid/framework/feed_fetch_type.h"
18+
#include "paddle/fluid/framework/version.h"
1819

1920
namespace paddle {
2021
namespace framework {
@@ -38,7 +39,10 @@ proto::ProgramDesc *ProgramDesc::Proto() {
3839
return &desc_;
3940
}
4041

42+
int64_t ProgramDesc::Version() const { return desc_.version().version(); }
43+
4144
ProgramDesc::ProgramDesc() {
45+
desc_.mutable_version()->set_version(kCurProgramVersion);
4246
auto *block = desc_.mutable_blocks()->Add();
4347
block->set_idx(kRootBlockIndex);
4448
block->set_parent_idx(kNoneBlockIndex);

paddle/fluid/framework/program_desc.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class ProgramDesc {
5757

5858
proto::ProgramDesc *Proto();
5959

60+
int64_t Version() const;
61+
6062
// The output variable of feed_op is referenced as feed_target.
6163
// This function is used to collect the output variable's name of all
6264
// feed_ops.

paddle/fluid/framework/version.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/version.h"
16+
#include <algorithm>
17+
18+
namespace paddle {
19+
namespace framework {
20+
bool IsProgramVersionSupported(int64_t version) {
21+
static int num_supported =
22+
sizeof(kSupportedProgramVersion) / sizeof(kSupportedProgramVersion[0]);
23+
return std::find(kSupportedProgramVersion,
24+
kSupportedProgramVersion + num_supported,
25+
version) != kSupportedProgramVersion + num_supported;
26+
}
27+
28+
bool IsTensorVersionSupported(uint32_t version) {
29+
static int num_supported =
30+
sizeof(kSupportedTensorVersion) / sizeof(kSupportedTensorVersion[0]);
31+
return std::find(kSupportedTensorVersion,
32+
kSupportedTensorVersion + num_supported,
33+
version) != kSupportedTensorVersion + num_supported;
34+
}
35+
} // namespace framework
36+
} // namespace paddle

paddle/fluid/framework/version.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <cstdint>
16+
17+
#pragma once
18+
19+
namespace paddle {
20+
namespace framework {
21+
22+
// Note:
23+
// Program and Tensor that pass the IsXXXVersionSupported should
24+
// be supported by the current codes. Otherwise, it's a compatibility
25+
// bug.
26+
27+
// The program version the current codes generate.
28+
constexpr int64_t kCurProgramVersion = 0;
29+
30+
// The program version that was generated by previous or current codes
31+
// and supported by current codes.
32+
constexpr int64_t kSupportedProgramVersion[] = {0};
33+
34+
// Due to historical reasons, tensor version use uint32_t.
35+
// The tensor version the current codes generate.
36+
constexpr uint32_t kCurTensorVersion = 0;
37+
38+
// The tensor version that was generated by previous or current codes
39+
// and supported by current codes.
40+
constexpr uint32_t kSupportedTensorVersion[] = {0};
41+
42+
bool IsProgramVersionSupported(int64_t version);
43+
44+
bool IsTensorVersionSupported(uint32_t version);
45+
46+
} // namespace framework
47+
} // namespace paddle
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/version.h"
16+
#include "gtest/gtest.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
TEST(Version, Basic) {
21+
EXPECT_TRUE(IsProgramVersionSupported(0));
22+
EXPECT_FALSE(IsProgramVersionSupported(1));
23+
EXPECT_FALSE(IsProgramVersionSupported(-1));
24+
25+
EXPECT_TRUE(IsTensorVersionSupported(0));
26+
EXPECT_FALSE(IsTensorVersionSupported(1));
27+
EXPECT_FALSE(IsTensorVersionSupported(-1));
28+
}
29+
} // namespace framework
30+
} // namespace paddle

paddle/fluid/inference/io.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020
#include "paddle/fluid/framework/block_desc.h"
2121
#include "paddle/fluid/framework/feed_fetch_type.h"
2222
#include "paddle/fluid/framework/op_registry.h"
23+
#include "paddle/fluid/framework/version.h"
2324
#include "paddle/fluid/platform/cpu_helper.h"
2425
#include "paddle/fluid/pybind/pybind.h"
2526

@@ -124,6 +125,9 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
124125

125126
std::unique_ptr<framework::ProgramDesc> main_program(
126127
new framework::ProgramDesc(program_desc_str));
128+
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
129+
"model version %ld is not supported.",
130+
main_program->Version());
127131

128132
LoadPersistables(executor, scope, *main_program, dirname, "");
129133
return main_program;
@@ -138,6 +142,9 @@ std::unique_ptr<framework::ProgramDesc> Load(
138142

139143
std::unique_ptr<framework::ProgramDesc> main_program(
140144
new framework::ProgramDesc(program_desc_str));
145+
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
146+
"model version %ld is not supported.",
147+
main_program->Version());
141148

142149
LoadPersistables(executor, scope, *main_program, "", param_filename);
143150
return main_program;

paddle/fluid/pybind/protobuf.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ void BindProgramDesc(pybind11::module *m) {
137137
PADDLE_ENFORCE(desc->ParseFromString(data),
138138
"Fail to parse ProgramDesc from string. This could "
139139
"be a bug of Paddle.");
140-
});
140+
})
141+
.def("_version", [](pd::ProgramDesc &self) -> int64_t {
142+
return self.Proto()->version().version();
143+
});
141144
}
142145

143146
void BindBlockDesc(pybind11::module *m) {

0 commit comments

Comments
 (0)