Skip to content

Commit 12a3cea

Browse files
chengduoabhinavarora
authored andcommitted
Add tuple type (#8519)
* add the type of tuple * add lod_tensor to tuple
1 parent d3fbede commit 12a3cea

File tree

4 files changed

+141
-0
lines changed

4 files changed

+141
-0
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,6 @@ cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_contex
9696
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
9797

9898
cc_test(channel_test SRCS channel_test.cc)
99+
cc_test(tuple_test SRCS tuple_test.cc )
99100
cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
100101
channel_send_op channel_recv_op sum_op elementwise_add_op executor proto_desc)

paddle/fluid/framework/framework.proto

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ message VarType {
117117
// raw variables should manage their own allocations
118118
// in operators like nccl_op
119119
RAW = 17;
120+
TUPLE = 18;
120121
}
121122

122123
required Type type = 1;
@@ -148,6 +149,9 @@ message VarType {
148149
required int64 capacity = 2;
149150
}
150151
optional ChannelDesc channel = 6;
152+
153+
message Tuple { repeated Type element_type = 1; }
154+
optional Tuple tuple = 7;
151155
}
152156

153157
message VarDesc {

paddle/fluid/framework/tuple.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <stdexcept>
18+
#include <string>
19+
#include <vector>
20+
#include "paddle/fluid/framework/channel.h"
21+
#include "paddle/fluid/framework/lod_tensor.h"
22+
#include "paddle/fluid/framework/tensor.h"
23+
#include "paddle/fluid/framework/var_desc.h"
24+
#include "paddle/fluid/platform/enforce.h"
25+
#include "paddle/fluid/platform/variant.h"
26+
27+
namespace paddle {
28+
namespace framework {
29+
30+
typedef boost::variant<int, int64_t, float, double, std::string, Tensor,
31+
LoDTensor /*, ChannelHolder*/>
32+
ElementVar;
33+
34+
class Tuple {
35+
public:
36+
using ElementVars = std::vector<ElementVar>;
37+
38+
Tuple(std::vector<ElementVar>& var, std::vector<VarDesc>& var_desc)
39+
: var_(var), var_desc_(var_desc) {}
40+
Tuple(std::vector<ElementVar>& var) : var_(var) {}
41+
42+
ElementVar get(int idx) const { return var_[idx]; };
43+
44+
ElementVar& get(int idx) { return var_[idx]; };
45+
46+
bool isSameType(Tuple& t) const;
47+
48+
size_t getSize() const { return var_.size(); };
49+
50+
private:
51+
ElementVars var_;
52+
std::vector<VarDesc> var_desc_;
53+
};
54+
55+
bool Tuple::isSameType(Tuple& t) const {
56+
size_t tuple_size = getSize();
57+
if (tuple_size != t.getSize()) {
58+
return false;
59+
}
60+
for (size_t j = 0; j < tuple_size; ++j) {
61+
auto type1 = get(j).which();
62+
auto type2 = t.get(j).which();
63+
if (type1 != type2) return false;
64+
}
65+
return true;
66+
}
67+
68+
Tuple* make_tuple(std::vector<ElementVar> tuple) { return new Tuple(tuple); }
69+
70+
} // namespace framework
71+
} // namespace paddle

paddle/fluid/framework/tuple_test.cc

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
#include <sstream>
15+
#include <vector>
16+
17+
#include "gtest/gtest.h"
18+
#include "paddle/fluid/framework/tuple.h"
19+
20+
TEST(Tuple, Make) {
21+
std::vector<paddle::framework::ElementVar> element_type;
22+
element_type.push_back(12);
23+
element_type.push_back(12.0f);
24+
element_type.push_back("ElementVar");
25+
26+
paddle::framework::Tuple* tuple = paddle::framework::make_tuple(element_type);
27+
28+
EXPECT_EQ(boost::get<int>(tuple->get(0)), 12);
29+
EXPECT_EQ(boost::get<float>(tuple->get(1)), 12.0f);
30+
EXPECT_EQ(boost::get<std::string>(tuple->get(2)), "ElementVar");
31+
32+
delete tuple;
33+
}
34+
35+
TEST(Tuple, IsTheSameType) {
36+
std::vector<paddle::framework::ElementVar> element_type1;
37+
std::vector<paddle::framework::ElementVar> element_type2;
38+
std::vector<paddle::framework::ElementVar> element_type3;
39+
40+
element_type1.push_back(12);
41+
element_type1.push_back(12.0f);
42+
element_type1.push_back("Tuple1");
43+
44+
element_type2.push_back(13);
45+
element_type2.push_back(13.0f);
46+
element_type2.push_back("Tuple2");
47+
48+
element_type3.push_back(14.0f);
49+
element_type3.push_back(14);
50+
element_type3.push_back("Tuple3");
51+
52+
paddle::framework::Tuple* tuple1 =
53+
paddle::framework::make_tuple(element_type1);
54+
paddle::framework::Tuple* tuple2 =
55+
paddle::framework::make_tuple(element_type2);
56+
paddle::framework::Tuple* tuple3 =
57+
paddle::framework::make_tuple(element_type3);
58+
59+
EXPECT_TRUE(tuple1->isSameType(*tuple2));
60+
EXPECT_FALSE(tuple1->isSameType(*tuple3));
61+
62+
delete tuple1;
63+
delete tuple2;
64+
delete tuple3;
65+
}

0 commit comments

Comments
 (0)