Skip to content

Commit 78c884d

Browse files
authored
Redesign channel implementation for Select Op (#8814)
* Redesign channel implementation for Select Op * Remove unecessary header * Remove unnecessary comments
1 parent 351795e commit 78c884d

File tree

5 files changed

+320
-389
lines changed

5 files changed

+320
-389
lines changed

paddle/fluid/framework/channel.h

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,19 @@ class Channel {
2828
virtual bool Send(T*) = 0;
2929
virtual bool Receive(T*) = 0;
3030
virtual size_t Cap() = 0;
31+
virtual void Lock() = 0;
32+
virtual void Unlock() = 0;
3133
virtual void Close() = 0;
3234
virtual ~Channel() {}
3335
};
3436

3537
// Forward declaration of channel implementations.
36-
namespace details {
3738
template <typename T>
38-
class Buffered;
39-
template <typename T>
40-
class UnBuffered;
41-
} // namespace details
39+
class ChannelImpl;
4240

4341
template <typename T>
4442
Channel<T>* MakeChannel(size_t buffer_size) {
45-
if (buffer_size > 0) {
46-
return new details::Buffered<T>(buffer_size);
47-
}
48-
return new details::UnBuffered<T>();
43+
return new ChannelImpl<T>(buffer_size);
4944
}
5045

5146
template <typename T>
@@ -89,6 +84,19 @@ class ChannelHolder {
8984
if (IsInitialized()) holder_->Close();
9085
}
9186

87+
size_t Cap() {
88+
if (IsInitialized()) return holder_->Cap();
89+
return -1;
90+
}
91+
92+
void Lock() {
93+
if (IsInitialized()) holder_->Lock();
94+
}
95+
96+
void Unlock() {
97+
if (IsInitialized()) holder_->Unlock();
98+
}
99+
92100
inline bool IsInitialized() const { return holder_ != nullptr; }
93101

94102
inline const std::type_index Type() {
@@ -106,6 +114,9 @@ class ChannelHolder {
106114
virtual const std::type_index Type() const = 0;
107115
virtual void* Ptr() const = 0;
108116
virtual void Close() = 0;
117+
virtual void Lock() = 0;
118+
virtual void Unlock() = 0;
119+
virtual size_t Cap() = 0;
109120
};
110121

111122
template <typename T>
@@ -115,11 +126,28 @@ class ChannelHolder {
115126
}
116127

117128
virtual const std::type_index Type() const { return type_; }
129+
118130
virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
131+
119132
virtual void Close() {
120133
if (channel_) channel_->Close();
121134
}
122135

136+
virtual size_t Cap() {
137+
if (channel_)
138+
return channel_->Cap();
139+
else
140+
return -1;
141+
}
142+
143+
virtual void Lock() {
144+
if (channel_) channel_->Lock();
145+
}
146+
147+
virtual void Unlock() {
148+
if (channel_) channel_->Unlock();
149+
}
150+
123151
std::unique_ptr<Channel<T>> channel_;
124152
const std::type_index type_;
125153
};
@@ -131,5 +159,4 @@ class ChannelHolder {
131159
} // namespace framework
132160
} // namespace paddle
133161

134-
#include "paddle/fluid/framework/details/buffered_channel.h"
135-
#include "paddle/fluid/framework/details/unbuffered_channel.h"
162+
#include "paddle/fluid/framework/channel_impl.h"

paddle/fluid/framework/channel_impl.h

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <stddef.h> // for size_t
17+
#include <atomic>
18+
#include <condition_variable>
19+
#include <deque>
20+
#include "paddle/fluid/framework/channel.h"
21+
#include "paddle/fluid/platform/enforce.h"
22+
23+
namespace paddle {
24+
namespace framework {
25+
26+
template <typename T>
27+
class ChannelImpl : public paddle::framework::Channel<T> {
28+
friend Channel<T> *paddle::framework::MakeChannel<T>(size_t);
29+
friend void paddle::framework::CloseChannel<T>(Channel<T> *);
30+
31+
public:
32+
virtual bool Send(T *);
33+
virtual bool Receive(T *);
34+
virtual size_t Cap() { return cap_; }
35+
virtual void Lock();
36+
virtual void Unlock();
37+
virtual void Close();
38+
39+
ChannelImpl(size_t);
40+
virtual ~ChannelImpl();
41+
42+
private:
43+
struct QueueMessage {
44+
T *data;
45+
std::condition_variable_any cond;
46+
bool chan_closed = false;
47+
bool completed = false;
48+
49+
QueueMessage(T *item) : data(item) {}
50+
51+
void Wait(std::unique_lock<std::recursive_mutex> &lock) {
52+
cond.wait(lock, [this]() { return completed; });
53+
}
54+
55+
void Notify() {
56+
completed = true;
57+
cond.notify_all();
58+
}
59+
};
60+
61+
bool send_return(bool value) {
62+
send_ctr--;
63+
destructor_cond_.notify_all();
64+
return value;
65+
}
66+
67+
bool recv_return(bool value) {
68+
recv_ctr--;
69+
destructor_cond_.notify_all();
70+
return value;
71+
}
72+
73+
size_t cap_;
74+
std::recursive_mutex mu_;
75+
bool closed_;
76+
std::deque<T> buf_;
77+
std::deque<std::shared_ptr<QueueMessage>> recvq;
78+
std::deque<std::shared_ptr<QueueMessage>> sendq;
79+
std::atomic<unsigned> send_ctr{0};
80+
std::atomic<unsigned> recv_ctr{0};
81+
std::condition_variable_any destructor_cond_;
82+
};
83+
84+
template <typename T>
85+
ChannelImpl<T>::ChannelImpl(size_t capacity)
86+
: cap_(capacity), closed_(false), send_ctr(0), recv_ctr(0) {
87+
PADDLE_ENFORCE_GE(capacity, 0);
88+
}
89+
90+
template <typename T>
91+
bool ChannelImpl<T>::Send(T *item) {
92+
send_ctr++;
93+
std::unique_lock<std::recursive_mutex> lock{mu_};
94+
95+
// If channel is closed, do nothing
96+
if (closed_) {
97+
lock.unlock();
98+
// TODO(abhinavarora) Should panic on closed channel
99+
return send_return(false);
100+
}
101+
102+
// If there is a receiver, directly pass the value we want
103+
// to send to the receiver, bypassing the channel buffer if any
104+
if (!recvq.empty()) {
105+
std::shared_ptr<QueueMessage> m = recvq.front();
106+
recvq.pop_front();
107+
// Do the data transfer
108+
*(m->data) = std::move(*item);
109+
// Wake up the blocked process and unlock
110+
m->Notify();
111+
lock.unlock();
112+
return send_return(true);
113+
}
114+
115+
// Unbuffered channel will always bypass this
116+
// If buffered channel has space in buffer,
117+
// write the element to the buffer.
118+
if (buf_.size() < cap_) {
119+
// Copy to buffer
120+
buf_.push_back(std::move(*item));
121+
// Release lock and return true
122+
lock.unlock();
123+
return send_return(true);
124+
}
125+
126+
// Block on channel, because some receiver will complete
127+
// the operation for us
128+
auto m = std::make_shared<QueueMessage>(item);
129+
sendq.push_back(m);
130+
m->Wait(lock);
131+
// TODO(abhinavarora) Should panic on closed channel
132+
return send_return(!m->chan_closed);
133+
}
134+
135+
template <typename T>
136+
bool ChannelImpl<T>::Receive(T *item) {
137+
recv_ctr++;
138+
std::unique_lock<std::recursive_mutex> lock{mu_};
139+
140+
// If channel is closed and buffer is empty or
141+
// channel is unbuffered
142+
if (closed_ && buf_.empty()) {
143+
lock.unlock();
144+
return recv_return(false);
145+
}
146+
147+
// If there is a sender, directly receive the value we want
148+
// from the sender, bypassing the channel buffer if any
149+
if (!sendq.empty()) {
150+
std::shared_ptr<QueueMessage> m = sendq.front();
151+
sendq.pop_front();
152+
// Do the data transfer
153+
*item = std::move(*(m->data));
154+
// Wake up the blocked process and unlock
155+
m->Notify();
156+
lock.unlock();
157+
return recv_return(true);
158+
}
159+
160+
// If this is a buffered channel and there are items in buffer
161+
if (buf_.size() > 0) {
162+
// Directly read from buffer
163+
*item = std::move(buf_.front());
164+
buf_.pop_front();
165+
// Release lock and return true
166+
lock.unlock();
167+
return recv_return(true);
168+
}
169+
170+
// No sender available, block on this channel
171+
// Some receiver will complete the option for us
172+
auto m = std::make_shared<QueueMessage>(item);
173+
recvq.push_back(m);
174+
m->Wait(lock);
175+
176+
return recv_return(!m->chan_closed);
177+
}
178+
179+
template <typename T>
180+
void ChannelImpl<T>::Lock() {
181+
mu_.lock();
182+
}
183+
184+
template <typename T>
185+
void ChannelImpl<T>::Unlock() {
186+
mu_.unlock();
187+
}
188+
189+
template <typename T>
190+
void ChannelImpl<T>::Close() {
191+
std::unique_lock<std::recursive_mutex> lock{mu_};
192+
193+
if (closed_) {
194+
// TODO(abhinavarora): closing an already closed channel should panic
195+
lock.unlock();
196+
return;
197+
}
198+
199+
closed_ = true;
200+
201+
// Empty the readers
202+
while (!recvq.empty()) {
203+
std::shared_ptr<QueueMessage> m = recvq.front();
204+
recvq.pop_front();
205+
m->chan_closed = true;
206+
m->Notify();
207+
}
208+
209+
// Empty the senders
210+
while (!sendq.empty()) {
211+
std::shared_ptr<QueueMessage> m = sendq.front();
212+
sendq.pop_front();
213+
m->chan_closed = true;
214+
m->Notify();
215+
}
216+
}
217+
218+
template <typename T>
219+
ChannelImpl<T>::~ChannelImpl() {
220+
Close();
221+
// The destructor must wait for all readers and writers to complete their task
222+
// The channel has been closed, so we will not accept new readers and writers
223+
std::unique_lock<std::recursive_mutex> lock{mu_};
224+
destructor_cond_.wait(lock,
225+
[this]() { return send_ctr == 0 && recv_ctr == 0; });
226+
}
227+
228+
} // namespace framework
229+
} // namespace paddle

0 commit comments

Comments
 (0)