Skip to content

Commit 77ee8fb

Browse files
author
kavyasrinet
authored
Exposing Channel to be used as a Variable and integrating with Fluid (#8486)
* Adding set_capacity method support * Adding Python for make_channel * Updating notest_concurrency * Write python for make_channel method * Write python for make_channel method * Fix make_channel and test * Placeholder ops for channel send, recv and close * Adding ToTypeIndex method to var_type.h * Add var_type.h to channel: * Added POD_Type to the method * Add CHANNEL to executor * Updated get and set DataType to accomodate Channels * Updating get and set to incorporate channels * Adding CHANNEL as supported VarType in protobuf * Removing unecessary import * Fixing VarDesc to adapt to Channel as VarType * Add channel.h to executor * Remove innclude from channel * Updated var_type to support Channel as var type * Adding get_channel to pybind * Added ChannelHolder * Adding make_channel as an op * Adding ChannelHolder in channel * Fixing typo * Commenting out operators in concurrency * Removing totypeid right now since we don't need it. * Reverting python changes * Fixing typo in framework.py * Modify comments for ReaderHolder
1 parent 88c22e9 commit 77ee8fb

File tree

7 files changed

+142
-3
lines changed

7 files changed

+142
-3
lines changed

paddle/fluid/framework/channel.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <stddef.h> // for size_t
18+
#include <typeindex>
19+
#include "paddle/fluid/platform/enforce.h"
1820

1921
namespace paddle {
2022
namespace framework {
@@ -51,6 +53,77 @@ void CloseChannel(Channel<T>* ch) {
5153
ch->Close();
5254
}
5355

56+
/*
57+
* The ChannelHolder class serves two main purposes:
58+
* 1. It acts as a unified wrapper for the different kinds of
59+
* channels, i.e. Buffered and Unbuffered channels. This is
60+
* similar to the ReaderHolder class.
61+
* 2. It also helps us in TypeHiding. This is similar to the
62+
* PlaceHolder implementations in variable.h and tensor.h.
63+
*/
64+
class ChannelHolder {
65+
public:
66+
template <typename T>
67+
void Reset(size_t buffer_size) {
68+
holder_.reset(new PlaceholderImpl<T>(buffer_size));
69+
}
70+
71+
template <typename T>
72+
bool Send(T* data) {
73+
if (!IsInitialized()) return false;
74+
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
75+
// Static cast should be safe because we have ensured that types are same
76+
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
77+
return channel != nullptr ? channel->Send(data) : false;
78+
}
79+
80+
template <typename T>
81+
bool Receive(T* data) {
82+
if (!IsInitialized()) return false;
83+
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
84+
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
85+
return channel != nullptr ? channel->Receive(data) : false;
86+
}
87+
88+
void close() {
89+
if (IsInitialized()) holder_->Close();
90+
}
91+
92+
inline bool IsInitialized() const { return holder_ != nullptr; }
93+
94+
private:
95+
/**
96+
* @note Placeholder hides type T, so it doesn't appear as a template
97+
* parameter of ChannelHolder.
98+
*/
99+
struct Placeholder {
100+
virtual ~Placeholder() {}
101+
virtual const std::type_index Type() const = 0;
102+
virtual void* Ptr() const = 0;
103+
virtual void Close() const = 0;
104+
std::type_info type_;
105+
};
106+
107+
template <typename T>
108+
struct PlaceholderImpl : public Placeholder {
109+
PlaceholderImpl(size_t buffer_size) : type_(std::type_index(typeid(T))) {
110+
channel_.reset(MakeChannel<T>(buffer_size));
111+
}
112+
113+
virtual const std::type_index Type() const { return type_; }
114+
virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
115+
virtual void Close() {
116+
if (channel_) channel_->Close();
117+
}
118+
119+
std::unique_ptr<Channel<T>*> channel_;
120+
const std::type_index type_;
121+
};
122+
123+
// Pointer to a PlaceholderImpl object
124+
std::unique_ptr<Placeholder> holder_;
125+
};
126+
54127
} // namespace framework
55128
} // namespace paddle
56129

paddle/fluid/framework/executor.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <set>
1818

1919
#include "gflags/gflags.h"
20+
#include "paddle/fluid/framework/channel.h"
2021
#include "paddle/fluid/framework/feed_fetch_method.h"
2122
#include "paddle/fluid/framework/feed_fetch_type.h"
2223
#include "paddle/fluid/framework/lod_rank_table.h"
@@ -55,13 +56,15 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
5556
var->GetMutable<platform::PlaceList>();
5657
} else if (var_type == proto::VarType::READER) {
5758
var->GetMutable<ReaderHolder>();
59+
} else if (var_type == proto::VarType::CHANNEL) {
60+
var->GetMutable<ChannelHolder>();
5861
} else if (var_type == proto::VarType::NCCL_COM) {
5962
// GetMutable will be called in ncclInit
6063
} else {
6164
PADDLE_THROW(
6265
"Variable type %d is not in "
6366
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
64-
"LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
67+
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
6568
var_type);
6669
}
6770
}

paddle/fluid/framework/var_desc.cc

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,13 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
8888
}
8989

9090
void VarDesc::SetDataType(proto::VarType::Type data_type) {
91-
mutable_tensor_desc()->set_data_type(data_type);
91+
switch (desc_.type().type()) {
92+
case proto::VarType::CHANNEL:
93+
mutable_channel_desc()->set_data_type(data_type);
94+
break;
95+
default:
96+
mutable_tensor_desc()->set_data_type(data_type);
97+
}
9298
}
9399

94100
void VarDesc::SetDataTypes(
@@ -109,7 +115,13 @@ void VarDesc::SetDataTypes(
109115
}
110116

111117
proto::VarType::Type VarDesc::GetDataType() const {
112-
return tensor_desc().data_type();
118+
switch (desc_.type().type()) {
119+
case proto::VarType::CHANNEL:
120+
return channel_desc().data_type();
121+
break;
122+
default:
123+
return tensor_desc().data_type();
124+
}
113125
}
114126

115127
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
@@ -122,6 +134,17 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
122134
return res;
123135
}
124136

137+
void VarDesc::SetCapacity(int64_t capacity) {
138+
switch (desc_.type().type()) {
139+
case proto::VarType::CHANNEL:
140+
desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
141+
break;
142+
default:
143+
PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
144+
this->Name());
145+
}
146+
}
147+
125148
void VarDesc::SetLoDLevel(int32_t lod_level) {
126149
switch (desc_.type().type()) {
127150
case proto::VarType::LOD_TENSOR:
@@ -191,6 +214,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
191214
}
192215
}
193216

217+
const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
218+
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
219+
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
220+
switch (desc_.type().type()) {
221+
case proto::VarType::CHANNEL:
222+
return desc_.type().channel();
223+
default:
224+
PADDLE_THROW(
225+
"Getting 'channel_desc' is not supported by the type of var %s.",
226+
this->Name());
227+
}
228+
}
229+
194230
const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
195231
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
196232
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
@@ -226,6 +262,20 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
226262
}
227263
}
228264

265+
proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
266+
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
267+
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
268+
switch (desc_.type().type()) {
269+
case proto::VarType::CHANNEL:
270+
return desc_.mutable_type()->mutable_channel();
271+
default:
272+
PADDLE_THROW(
273+
"Getting 'mutable_channel_desc' is not supported by the type of var "
274+
"%s.",
275+
this->Name());
276+
}
277+
}
278+
229279
proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
230280
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
231281
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");

paddle/fluid/framework/var_desc.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ class VarDesc {
8585
void SetDataTypes(
8686
const std::vector<proto::VarType::Type> &multiple_data_type);
8787

88+
void SetCapacity(int64_t capacity);
89+
8890
proto::VarType::Type GetDataType() const;
8991

9092
std::vector<proto::VarType::Type> GetDataTypes() const;
@@ -106,8 +108,10 @@ class VarDesc {
106108
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
107109

108110
private:
111+
const proto::VarType::ChannelDesc &channel_desc() const;
109112
const proto::VarType::TensorDesc &tensor_desc() const;
110113
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
114+
proto::VarType::ChannelDesc *mutable_channel_desc();
111115
proto::VarType::TensorDesc *mutable_tensor_desc();
112116
std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
113117

paddle/fluid/framework/var_type.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16+
#include "paddle/fluid/framework/channel.h"
1617
#include "paddle/fluid/framework/framework.pb.h"
1718
#include "paddle/fluid/framework/lod_rank_table.h"
1819
#include "paddle/fluid/framework/lod_tensor.h"
@@ -34,6 +35,8 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
3435
return proto::VarType_Type_SELECTED_ROWS;
3536
} else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
3637
return proto::VarType_Type_READER;
38+
} else if (type.hash_code() == typeid(ChannelHolder).hash_code()) {
39+
return proto::VarType_Type_CHANNEL;
3740
} else {
3841
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
3942
}
@@ -57,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
5760
case proto::VarType_Type_READER:
5861
visitor(var.Get<ReaderHolder>());
5962
return;
63+
case proto::VarType_Type_CHANNEL:
64+
visitor(var.Get<ChannelHolder>());
65+
return;
6066
default:
6167
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
6268
}

paddle/fluid/pybind/protobuf.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ void BindVarDsec(py::module &m) {
216216
.def("set_shapes", &VarDesc::SetShapes)
217217
.def("set_dtype", &VarDesc::SetDataType)
218218
.def("set_dtypes", &VarDesc::SetDataTypes)
219+
.def("set_capacity", &VarDesc::SetCapacity)
219220
.def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
220221
.def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
221222
.def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
@@ -246,6 +247,7 @@ void BindVarDsec(py::module &m) {
246247
.value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
247248
.value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
248249
.value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
250+
.value("CHANNEL", proto::VarType::CHANNEL)
249251
.value("PLACE_LIST", proto::VarType::PLACE_LIST)
250252
.value("READER", proto::VarType::READER)
251253
.value("NCCL_COM", proto::VarType::NCCL_COM);

paddle/fluid/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <mutex> // for call_once
1818
#include <unordered_map>
1919
#include "paddle/fluid/framework/backward.h"
20+
#include "paddle/fluid/framework/channel.h"
2021
#include "paddle/fluid/framework/executor.h"
2122
#include "paddle/fluid/framework/feed_fetch_method.h"
2223
#include "paddle/fluid/framework/framework.pb.h"

0 commit comments

Comments
 (0)