Skip to content

Commit 1e4c504

Browse files
cs2be, abhinavarora
authored and committed
Implement Select OP (#9088)
* Fix old documentation for channel_recv * Initial design of CSP select * Redesign channel implementation for Select Op * Remove unnecessary header * Initial check-in of select op, currently will read all the conditional_op in the cases block and also pull out all channels involved in the select. * Init python select op API * Python select bug fix when checking op creates block * Add case_to_execute as (a) input to select, (b) into the passed inputs into the select op * Add additional code for select op * Init fibonacci test from python * implement fibonacci sequence test * update fib unit test * Improve select test cases * Shorten non-pep-8-ed lines * Add methods on channel needed by select op * Fix compile issues, finish implementation, still need to debug code * Fix issue with fibonacci test, it works now! * Change QueueMessage callback to take in a ChannelAction enum, fix select unit test * Fix case attributes * Fix issue with select control flow * Make cases - previously on each select case conditional_block - attributes to select * Use class constants for type of channel * Change select op to take in "cases" attribute * return boolean from select callback function to tell Channel if this RECV or SEND should be executed * Improve attributes and inputs comments on select op * Fix issues with python unit test * Assert fibonacci final output * Fix issue when channel name / channel var is null for "default" case in select op * Assert base select test output * Make QueueMessage use shared pointer and modify the order of the callback * Fixing the order in which the callback is called * Move channel utility methods to paddle/fluid/operators/concurrency/channel_util * Create channel_util and move channel util methods * Fix crash when calling select_op * Fix deadlock * Fix issue of channel destructor deadlock * Fix precommit issues * Accidentally checked in changes to beam_search_op, reverting change. 
* Fix dependency issue in concurrency cmake * add device_context dependency for concurrency target
1 parent 45073b7 commit 1e4c504

File tree

16 files changed

+1096
-91
lines changed

16 files changed

+1096
-91
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,5 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
103103
cc_test(channel_test SRCS channel_test.cc)
104104
cc_test(tuple_test SRCS tuple_test.cc )
105105
cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
106-
channel_send_op channel_recv_op sum_op elementwise_add_op executor proto_desc)
106+
channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
107+
conditional_block_op while_op assign_op print_op executor proto_desc)

paddle/fluid/framework/channel.h

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -162,24 +162,12 @@ class ChannelHolder {
162162
}
163163
}
164164

165-
template <typename T>
166165
void RemoveFromSendQ(const void* referrer) {
167-
if (IsInitialized()) {
168-
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
169-
if (channel != nullptr) {
170-
channel->RemoveFromSendQ(referrer);
171-
}
172-
}
166+
if (IsInitialized()) holder_->RemoveFromSendQ(referrer);
173167
}
174168

175-
template <typename T>
176169
void RemoveFromReceiveQ(const void* referrer) {
177-
if (IsInitialized()) {
178-
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
179-
if (channel != nullptr) {
180-
channel->RemoveFromReceiveQ(referrer);
181-
}
182-
}
170+
if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer);
183171
}
184172

185173
inline bool IsInitialized() const { return holder_ != nullptr; }
@@ -201,6 +189,8 @@ class ChannelHolder {
201189
virtual bool IsClosed() = 0;
202190
virtual bool CanSend() = 0;
203191
virtual bool CanReceive() = 0;
192+
virtual void RemoveFromSendQ(const void* referrer) = 0;
193+
virtual void RemoveFromReceiveQ(const void* referrer) = 0;
204194
virtual void Close() = 0;
205195
virtual void Lock() = 0;
206196
virtual void Unlock() = 0;
@@ -238,6 +228,18 @@ class ChannelHolder {
238228
return false;
239229
}
240230

231+
virtual void RemoveFromSendQ(const void* referrer) {
232+
if (channel_) {
233+
channel_->RemoveFromSendQ(referrer);
234+
}
235+
}
236+
237+
virtual void RemoveFromReceiveQ(const void* referrer) {
238+
if (channel_) {
239+
channel_->RemoveFromReceiveQ(referrer);
240+
}
241+
}
242+
241243
virtual void Close() {
242244
if (channel_) channel_->Close();
243245
}

paddle/fluid/framework/channel_impl.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ bool ChannelImpl<T>::Send(T *item) {
151151
// We do not care about notifying other
152152
// because they would have been notified
153153
// by the executed select case.
154-
return Send(item);
154+
return send_return(Send(item));
155155

156156
// Wake up the blocked process and unlock
157157
m->Notify();
@@ -214,7 +214,7 @@ bool ChannelImpl<T>::Receive(T *item) {
214214
// We do not care about notifying other
215215
// because they would have been notified
216216
// by the executed select case.
217-
return Receive(item);
217+
return recv_return(Receive(item));
218218

219219
// Wake up the blocked process and unlock
220220
m->Notify();
@@ -331,7 +331,6 @@ void ChannelImpl<T>::RemoveFromSendQ(const void *referrer) {
331331

332332
if (sendMsg->referrer == referrer) {
333333
it = sendq.erase(it);
334-
send_ctr--;
335334
} else {
336335
++it;
337336
}
@@ -347,7 +346,6 @@ void ChannelImpl<T>::RemoveFromReceiveQ(const void *referrer) {
347346

348347
if (recvMsg->referrer == referrer) {
349348
it = recvq.erase(it);
350-
recv_ctr--;
351349
} else {
352350
++it;
353351
}

paddle/fluid/framework/concurrency_test.cc

Lines changed: 189 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,19 @@ limitations under the License. */
1919
#include "paddle/fluid/framework/channel.h"
2020
#include "paddle/fluid/framework/executor.h"
2121
#include "paddle/fluid/framework/op_registry.h"
22-
#include "paddle/fluid/framework/program_desc.h"
2322

2423
USE_NO_KERNEL_OP(go);
2524
USE_NO_KERNEL_OP(channel_close);
2625
USE_NO_KERNEL_OP(channel_create);
2726
USE_NO_KERNEL_OP(channel_recv);
2827
USE_NO_KERNEL_OP(channel_send);
2928
USE_NO_KERNEL_OP(elementwise_add);
29+
USE_NO_KERNEL_OP(select);
30+
USE_NO_KERNEL_OP(conditional_block);
31+
USE_NO_KERNEL_OP(equal);
32+
USE_NO_KERNEL_OP(assign);
33+
USE_NO_KERNEL_OP(while);
34+
USE_NO_KERNEL_OP(print);
3035

3136
namespace f = paddle::framework;
3237
namespace p = paddle::platform;
@@ -35,27 +40,15 @@ namespace paddle {
3540
namespace framework {
3641

3742
template <typename T>
38-
void CreateIntVariable(Scope &scope, p::CPUPlace &place, std::string name,
39-
T value) {
40-
// Create LoDTensor<int> of dim [1,1]
43+
LoDTensor *CreateVariable(Scope &scope, p::CPUPlace &place, std::string name,
44+
T value) {
45+
// Create LoDTensor<int> of dim [1]
4146
auto var = scope.Var(name);
4247
auto tensor = var->GetMutable<LoDTensor>();
43-
tensor->Resize({1, 1});
48+
tensor->Resize({1});
4449
T *expect = tensor->mutable_data<T>(place);
4550
expect[0] = value;
46-
}
47-
48-
void InitTensorsInScope(Scope &scope, p::CPUPlace &place) {
49-
p::CPUDeviceContext ctx(place);
50-
51-
// Create channel variable
52-
scope.Var("Channel");
53-
54-
// Create Variables, x0 will be put into channel,
55-
// result will be pulled from channel
56-
CreateIntVariable(scope, place, "Status", false);
57-
CreateIntVariable(scope, place, "x0", 99);
58-
CreateIntVariable(scope, place, "result", 0);
51+
return tensor;
5952
}
6053

6154
void AddOp(const std::string &type, const VariableNameMap &inputs,
@@ -73,12 +66,116 @@ void AddOp(const std::string &type, const VariableNameMap &inputs,
7366
op->SetAttrMap(attrs);
7467
}
7568

69+
void AddCase(ProgramDesc *program, Scope *scope, p::CPUPlace *place,
70+
BlockDesc *casesBlock, int caseId, int caseType,
71+
std::string caseChannel, std::string caseVarName,
72+
std::function<void(BlockDesc *, Scope *)> func) {
73+
std::string caseCondName = std::string("caseCond") + std::to_string(caseId);
74+
std::string caseCondXVarName =
75+
std::string("caseCondX") + std::to_string(caseId);
76+
77+
BlockDesc *caseBlock = program->AppendBlock(*casesBlock);
78+
func(caseBlock, scope);
79+
80+
CreateVariable(*scope, *place, caseCondName, false);
81+
CreateVariable(*scope, *place, caseCondXVarName, caseId);
82+
CreateVariable(*scope, *place, caseVarName, caseId);
83+
84+
scope->Var("step_scope");
85+
86+
AddOp("equal", {{"X", {caseCondXVarName}}, {"Y", {"caseToExecute"}}},
87+
{{"Out", {caseCondName}}}, {}, casesBlock);
88+
89+
AddOp("conditional_block", {{"X", {caseCondName}}, {"Params", {}}},
90+
{{"Out", {}}, {"Scope", {"step_scope"}}},
91+
{{"sub_block", caseBlock}, {"is_scalar_condition", true}}, casesBlock);
92+
}
93+
94+
void AddFibonacciSelect(Scope *scope, p::CPUPlace *place, ProgramDesc *program,
95+
BlockDesc *parentBlock, std::string dataChanName,
96+
std::string quitChanName) {
97+
BlockDesc *whileBlock = program->AppendBlock(*parentBlock);
98+
99+
CreateVariable(*scope, *place, "whileExitCond", true);
100+
CreateVariable(*scope, *place, "caseToExecute", -1);
101+
CreateVariable(*scope, *place, "case1var", 0);
102+
103+
CreateVariable(*scope, *place, "xtemp", 0);
104+
105+
// TODO(thuan): Need to create fibXToSend, since channel send moves the actual
106+
// data,
107+
// which causes the data to be no longer accessible to do the fib calculation
108+
// TODO(abhinav): Change channel send to do a copy instead of a move!
109+
CreateVariable(*scope, *place, "fibXToSend", 0);
110+
111+
CreateVariable(*scope, *place, "fibX", 0);
112+
CreateVariable(*scope, *place, "fibY", 1);
113+
CreateVariable(*scope, *place, "quitVar", 0);
114+
115+
BlockDesc *casesBlock = program->AppendBlock(*whileBlock);
116+
std::function<void(BlockDesc * caseBlock)> f = [](BlockDesc *caseBlock) {};
117+
118+
// TODO(thuan): Remove this once we change channel send to do a copy instead
119+
// of move
120+
AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"fibXToSend"}}}, {}, whileBlock);
121+
122+
// Case 0: Send to dataChanName
123+
std::function<void(BlockDesc * caseBlock, Scope * scope)> case0Func = [&](
124+
BlockDesc *caseBlock, Scope *scope) {
125+
AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"xtemp"}}}, {}, caseBlock);
126+
AddOp("assign", {{"X", {"fibY"}}}, {{"Out", {"fibX"}}}, {}, caseBlock);
127+
AddOp("elementwise_add", {{"X", {"xtemp"}}, {"Y", {"fibY"}}},
128+
{{"Out", {"fibY"}}}, {}, caseBlock);
129+
};
130+
AddCase(program, scope, place, casesBlock, 0, 1, dataChanName, "fibXToSend",
131+
case0Func);
132+
std::string case0Config =
133+
std::string("0,1,") + dataChanName + std::string(",fibXToSend");
134+
135+
// Case 1: Receive from quitChanName
136+
std::function<void(BlockDesc * caseBlock, Scope * scope)> case2Func = [&](
137+
BlockDesc *caseBlock, Scope *scope) {
138+
// Exit the while loop after we receive from quit channel.
139+
// We assign a false to "whileExitCond" variable, which will
140+
// break out of while_op loop
141+
CreateVariable(*scope, *place, "whileFalse", false);
142+
AddOp("assign", {{"X", {"whileFalse"}}}, {{"Out", {"whileExitCond"}}}, {},
143+
caseBlock);
144+
};
145+
AddCase(program, scope, place, casesBlock, 1, 2, quitChanName, "quitVar",
146+
case2Func);
147+
std::string case1Config =
148+
std::string("1,2,") + quitChanName + std::string(",quitVar");
149+
150+
// Select block
151+
AddOp("select", {{"X", {dataChanName, quitChanName}},
152+
{"case_to_execute", {"caseToExecute"}}},
153+
{}, {{"sub_block", casesBlock},
154+
{"cases", std::vector<std::string>{case0Config, case1Config}}},
155+
whileBlock);
156+
157+
scope->Var("stepScopes");
158+
AddOp("while",
159+
{{"X", {dataChanName, quitChanName}}, {"Condition", {"whileExitCond"}}},
160+
{{"Out", {}}, {"StepScopes", {"stepScopes"}}},
161+
{{"sub_block", whileBlock}}, parentBlock);
162+
}
163+
76164
TEST(Concurrency, Go_Op) {
77165
Scope scope;
78166
p::CPUPlace place;
79167

80168
// Initialize scope variables
81-
InitTensorsInScope(scope, place);
169+
p::CPUDeviceContext ctx(place);
170+
171+
// Create channel variable
172+
scope.Var("Channel");
173+
174+
// Create Variables, x0 will be put into channel,
175+
// result will be pulled from channel
176+
CreateVariable(scope, place, "Status", false);
177+
CreateVariable(scope, place, "x0", 99);
178+
CreateVariable(scope, place, "result", 0);
82179

83180
framework::Executor executor(place);
84181
ProgramDesc program;
@@ -118,5 +215,78 @@ TEST(Concurrency, Go_Op) {
118215
auto *finalData = tensor.data<int>();
119216
EXPECT_EQ(finalData[0], 99);
120217
}
218+
219+
/**
220+
* This test implements the fibonacci function using go_op and select_op
221+
*/
222+
TEST(Concurrency, Select) {
223+
Scope scope;
224+
p::CPUPlace place;
225+
226+
// Initialize scope variables
227+
p::CPUDeviceContext ctx(place);
228+
229+
CreateVariable(scope, place, "Status", false);
230+
CreateVariable(scope, place, "result", 0);
231+
CreateVariable(scope, place, "currentXFib", 0);
232+
233+
framework::Executor executor(place);
234+
ProgramDesc program;
235+
BlockDesc *block = program.MutableBlock(0);
236+
237+
// Create channel OP
238+
std::string dataChanName = "Channel";
239+
scope.Var(dataChanName);
240+
AddOp("channel_create", {}, {{"Out", {dataChanName}}},
241+
{{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);
242+
243+
std::string quitChanName = "Quit";
244+
scope.Var(quitChanName);
245+
AddOp("channel_create", {}, {{"Out", {quitChanName}}},
246+
{{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);
247+
248+
// Create Go Op routine, which loops 10 times over fibonacci sequence
249+
CreateVariable(scope, place, "xReceiveVar", 0);
250+
251+
BlockDesc *goOpBlock = program.AppendBlock(program.Block(0));
252+
for (int i = 0; i < 10; ++i) {
253+
AddOp("channel_recv", {{"Channel", {dataChanName}}},
254+
{{"Status", {"Status"}}, {"Out", {"currentXFib"}}}, {}, goOpBlock);
255+
AddOp("print", {{"In", {"currentXFib"}}}, {{"Out", {"currentXFib"}}},
256+
{{"first_n", 100},
257+
{"summarize", -1},
258+
{"print_tensor_name", false},
259+
{"print_tensor_type", true},
260+
{"print_tensor_shape", false},
261+
{"print_tensor_lod", false},
262+
{"print_phase", std::string("FORWARD")},
263+
{"message", std::string("X: ")}},
264+
goOpBlock);
265+
}
266+
267+
CreateVariable(scope, place, "quitSignal", 0);
268+
AddOp("channel_send", {{"Channel", {quitChanName}}, {"X", {"quitSignal"}}},
269+
{{"Status", {"Status"}}}, {}, goOpBlock);
270+
271+
// Create Go Op
272+
AddOp("go", {{"X", {dataChanName, quitChanName}}}, {},
273+
{{"sub_block", goOpBlock}}, block);
274+
275+
AddFibonacciSelect(&scope, &place, &program, block, dataChanName,
276+
quitChanName);
277+
278+
// Create Channel Close Op
279+
AddOp("channel_close", {{"Channel", {dataChanName}}}, {}, {}, block);
280+
AddOp("channel_close", {{"Channel", {quitChanName}}}, {}, {}, block);
281+
282+
executor.Run(program, &scope, 0, true, true);
283+
284+
// After we call executor.run, "result" variable should be equal to 34
285+
// (which is 10 loops through fibonacci sequence)
286+
const LoDTensor &tensor = (scope.FindVar("currentXFib"))->Get<LoDTensor>();
287+
auto *finalData = tensor.data<int>();
288+
EXPECT_EQ(finalData[0], 34);
289+
}
290+
121291
} // namespace framework
122292
} // namespace paddle

paddle/fluid/operators/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,11 @@ op_library(save_combine_op DEPS lod_tensor)
203203
op_library(load_combine_op DEPS lod_tensor)
204204
op_library(concat_op DEPS concat)
205205

206+
# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency
207+
add_subdirectory(concurrency)
208+
op_library(channel_send_op DEPS concurrency)
209+
op_library(channel_recv_op DEPS concurrency)
210+
206211
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
207212
foreach(src ${GENERAL_OPS})
208213
op_library(${src})

0 commit comments

Comments
 (0)