@@ -19,14 +19,19 @@ limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
- #include "paddle/fluid/framework/program_desc.h"

USE_NO_KERNEL_OP(go);
USE_NO_KERNEL_OP(channel_close);
USE_NO_KERNEL_OP(channel_create);
USE_NO_KERNEL_OP(channel_recv);
USE_NO_KERNEL_OP(channel_send);
USE_NO_KERNEL_OP(elementwise_add);
+ USE_NO_KERNEL_OP(select);
+ USE_NO_KERNEL_OP(conditional_block);
+ USE_NO_KERNEL_OP(equal);
+ USE_NO_KERNEL_OP(assign);
+ USE_NO_KERNEL_OP(while);
+ USE_NO_KERNEL_OP(print);

namespace f = paddle::framework;
namespace p = paddle::platform;
@@ -35,27 +40,15 @@ namespace paddle {
namespace framework {

template <typename T>
- void CreateIntVariable(Scope &scope, p::CPUPlace &place, std::string name,
-                        T value) {
-   // Create LoDTensor<int> of dim [1,1]
+ LoDTensor *CreateVariable(Scope &scope, p::CPUPlace &place, std::string name,
+                           T value) {
+   // Create LoDTensor<int> of dim [1]
  auto var = scope.Var(name);
  auto tensor = var->GetMutable<LoDTensor>();
-   tensor->Resize({1, 1});
+   tensor->Resize({1});
  T *expect = tensor->mutable_data<T>(place);
  expect[0] = value;
- }
-
- void InitTensorsInScope(Scope &scope, p::CPUPlace &place) {
-   p::CPUDeviceContext ctx(place);
-
-   // Create channel variable
-   scope.Var("Channel");
-
-   // Create Variables, x0 will be put into channel,
-   // result will be pulled from channel
-   CreateIntVariable(scope, place, "Status", false);
-   CreateIntVariable(scope, place, "x0", 99);
-   CreateIntVariable(scope, place, "result", 0);
+   return tensor;
}

void AddOp(const std::string &type, const VariableNameMap &inputs,
@@ -73,12 +66,116 @@ void AddOp(const std::string &type, const VariableNameMap &inputs,
  op->SetAttrMap(attrs);
}

+ void AddCase(ProgramDesc *program, Scope *scope, p::CPUPlace *place,
+              BlockDesc *casesBlock, int caseId, int caseType,
+              std::string caseChannel, std::string caseVarName,
+              std::function<void(BlockDesc *, Scope *)> func) {
+   std::string caseCondName = std::string("caseCond") + std::to_string(caseId);
+   std::string caseCondXVarName =
+       std::string("caseCondX") + std::to_string(caseId);
+
+   BlockDesc *caseBlock = program->AppendBlock(*casesBlock);
+   func(caseBlock, scope);
+
+   CreateVariable(*scope, *place, caseCondName, false);
+   CreateVariable(*scope, *place, caseCondXVarName, caseId);
+   CreateVariable(*scope, *place, caseVarName, caseId);
+
+   scope->Var("step_scope");
+
+   AddOp("equal", {{"X", {caseCondXVarName}}, {"Y", {"caseToExecute"}}},
+         {{"Out", {caseCondName}}}, {}, casesBlock);
+
+   AddOp("conditional_block", {{"X", {caseCondName}}, {"Params", {}}},
+         {{"Out", {}}, {"Scope", {"step_scope"}}},
+         {{"sub_block", caseBlock}, {"is_scalar_condition", true}}, casesBlock);
+ }
+
+ void AddFibonacciSelect(Scope *scope, p::CPUPlace *place, ProgramDesc *program,
+                         BlockDesc *parentBlock, std::string dataChanName,
+                         std::string quitChanName) {
+   BlockDesc *whileBlock = program->AppendBlock(*parentBlock);
+
+   CreateVariable(*scope, *place, "whileExitCond", true);
+   CreateVariable(*scope, *place, "caseToExecute", -1);
+   CreateVariable(*scope, *place, "case1var", 0);
+
+   CreateVariable(*scope, *place, "xtemp", 0);
+
+   // TODO(thuan): Need to create fibXToSend, since channel_send moves the
+   // actual data, which makes it no longer accessible for the fib calculation.
+   // TODO(abhinav): Change channel send to do a copy instead of a move!
+   CreateVariable(*scope, *place, "fibXToSend", 0);
+
+   CreateVariable(*scope, *place, "fibX", 0);
+   CreateVariable(*scope, *place, "fibY", 1);
+   CreateVariable(*scope, *place, "quitVar", 0);
+
+   BlockDesc *casesBlock = program->AppendBlock(*whileBlock);
+   std::function<void(BlockDesc * caseBlock)> f = [](BlockDesc *caseBlock) {};
+
+   // TODO(thuan): Remove this once we change channel send to do a copy instead
+   // of a move
+   AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"fibXToSend"}}}, {}, whileBlock);
+
+   // Case 0: Send to dataChanName
+   std::function<void(BlockDesc * caseBlock, Scope * scope)> case0Func = [&](
+       BlockDesc *caseBlock, Scope *scope) {
+     AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"xtemp"}}}, {}, caseBlock);
+     AddOp("assign", {{"X", {"fibY"}}}, {{"Out", {"fibX"}}}, {}, caseBlock);
+     AddOp("elementwise_add", {{"X", {"xtemp"}}, {"Y", {"fibY"}}},
+           {{"Out", {"fibY"}}}, {}, caseBlock);
+   };
+   AddCase(program, scope, place, casesBlock, 0, 1, dataChanName, "fibXToSend",
+           case0Func);
+   std::string case0Config =
+       std::string("0,1,") + dataChanName + std::string(",fibXToSend");
+
+   // Case 1: Receive from quitChanName
+   std::function<void(BlockDesc * caseBlock, Scope * scope)> case2Func = [&](
+       BlockDesc *caseBlock, Scope *scope) {
+     // Exit the while loop after we receive from the quit channel.
+     // We assign false to "whileExitCond", which breaks out of the
+     // while_op loop.
+     CreateVariable(*scope, *place, "whileFalse", false);
+     AddOp("assign", {{"X", {"whileFalse"}}}, {{"Out", {"whileExitCond"}}}, {},
+           caseBlock);
+   };
+   AddCase(program, scope, place, casesBlock, 1, 2, quitChanName, "quitVar",
+           case2Func);
+   std::string case1Config =
+       std::string("1,2,") + quitChanName + std::string(",quitVar");
+
+   // Select block
+   AddOp("select", {{"X", {dataChanName, quitChanName}},
+                    {"case_to_execute", {"caseToExecute"}}},
+         {}, {{"sub_block", casesBlock},
+              {"cases", std::vector<std::string>{case0Config, case1Config}}},
+         whileBlock);
+
+   scope->Var("stepScopes");
+   AddOp("while",
+         {{"X", {dataChanName, quitChanName}}, {"Condition", {"whileExitCond"}}},
+         {{"Out", {}}, {"StepScopes", {"stepScopes"}}},
+         {{"sub_block", whileBlock}}, parentBlock);
+ }
+
TEST(Concurrency, Go_Op) {
  Scope scope;
  p::CPUPlace place;

  // Initialize scope variables
-   InitTensorsInScope(scope, place);
+   p::CPUDeviceContext ctx(place);
+
+   // Create channel variable
+   scope.Var("Channel");
+
+   // Create Variables, x0 will be put into channel,
+   // result will be pulled from channel
+   CreateVariable(scope, place, "Status", false);
+   CreateVariable(scope, place, "x0", 99);
+   CreateVariable(scope, place, "result", 0);

  framework::Executor executor(place);
  ProgramDesc program;
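
The expected value checked in the next hunk follows directly from the ops added above: each pass through the select's case 0 sends fibX and then advances the pair with the assign/assign/elementwise_add ops, so the ten values the go block receives are 0, 1, 1, 2, 3, 5, 8, 13, 21, 34. A minimal standalone sketch of that arithmetic (plain C++, no Paddle dependencies, not part of this diff):

#include <cstdio>

int main() {
  int fibX = 0, fibY = 1, current = 0;
  for (int i = 0; i < 10; ++i) {  // ten channel_recv iterations in the go block
    current = fibX;               // the value sent on the data channel
    int xtemp = fibX;             // assign:          fibX -> xtemp
    fibX = fibY;                  // assign:          fibY -> fibX
    fibY = xtemp + fibY;          // elementwise_add: fibY = xtemp + fibY
  }
  std::printf("%d\n", current);   // prints 34, matching the EXPECT_EQ below
  return 0;
}
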
@@ -118,5 +215,78 @@ TEST(Concurrency, Go_Op) {
  auto *finalData = tensor.data<int>();
  EXPECT_EQ(finalData[0], 99);
}

+ /**
+  * This test implements the fibonacci function using go_op and select_op
+  */
+ TEST(Concurrency, Select) {
+   Scope scope;
+   p::CPUPlace place;
+
+   // Initialize scope variables
+   p::CPUDeviceContext ctx(place);
+
+   CreateVariable(scope, place, "Status", false);
+   CreateVariable(scope, place, "result", 0);
+   CreateVariable(scope, place, "currentXFib", 0);
+
+   framework::Executor executor(place);
+   ProgramDesc program;
+   BlockDesc *block = program.MutableBlock(0);
+
+   // Create channel OP
+   std::string dataChanName = "Channel";
+   scope.Var(dataChanName);
+   AddOp("channel_create", {}, {{"Out", {dataChanName}}},
+         {{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);
+
+   std::string quitChanName = "Quit";
+   scope.Var(quitChanName);
+   AddOp("channel_create", {}, {{"Out", {quitChanName}}},
+         {{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);
+
+   // Create Go Op routine, which loops 10 times over fibonacci sequence
+   CreateVariable(scope, place, "xReceiveVar", 0);
+
+   BlockDesc *goOpBlock = program.AppendBlock(program.Block(0));
+   for (int i = 0; i < 10; ++i) {
+     AddOp("channel_recv", {{"Channel", {dataChanName}}},
+           {{"Status", {"Status"}}, {"Out", {"currentXFib"}}}, {}, goOpBlock);
+     AddOp("print", {{"In", {"currentXFib"}}}, {{"Out", {"currentXFib"}}},
+           {{"first_n", 100},
+            {"summarize", -1},
+            {"print_tensor_name", false},
+            {"print_tensor_type", true},
+            {"print_tensor_shape", false},
+            {"print_tensor_lod", false},
+            {"print_phase", std::string("FORWARD")},
+            {"message", std::string("X: ")}},
+           goOpBlock);
+   }
+
+   CreateVariable(scope, place, "quitSignal", 0);
+   AddOp("channel_send", {{"Channel", {quitChanName}}, {"X", {"quitSignal"}}},
+         {{"Status", {"Status"}}}, {}, goOpBlock);
+
+   // Create Go Op
+   AddOp("go", {{"X", {dataChanName, quitChanName}}}, {},
+         {{"sub_block", goOpBlock}}, block);
+
+   AddFibonacciSelect(&scope, &place, &program, block, dataChanName,
+                      quitChanName);
+
+   // Create Channel Close Op
+   AddOp("channel_close", {{"Channel", {dataChanName}}}, {}, {}, block);
+   AddOp("channel_close", {{"Channel", {quitChanName}}}, {}, {}, block);
+
+   executor.Run(program, &scope, 0, true, true);
+
+   // After executor.Run, "currentXFib" should be equal to 34
+   // (the last of the 10 values received from the fibonacci sequence)
+   const LoDTensor &tensor = (scope.FindVar("currentXFib"))->Get<LoDTensor>();
+   auto *finalData = tensor.data<int>();
+   EXPECT_EQ(finalData[0], 34);
+ }
+
} // namespace framework
} // namespace paddle
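
For readers less familiar with the CSP-style ops, the program the Select test builds has roughly the shape of the plain-C++ sketch below. This is an illustration only, not Paddle code: a producer thread stands in for the while/select block built by AddFibonacciSelect, a consumer stands in for the go_op block that receives ten values, an unbounded mutex/condition_variable queue stands in for the channel ops, and an atomic flag approximates the quit-channel select case, so the blocking semantics are only an approximation of the unbuffered channels used in the test.

#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

// Simplified buffered channel; the test's channel_create ops build
// capacity-0 channels, which rendezvous instead of queueing.
template <typename T>
class Channel {
 public:
  void Send(T v) {
    {
      std::lock_guard<std::mutex> g(mu_);
      q_.push(v);
    }
    cv_.notify_one();
  }
  T Recv() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return !q_.empty(); });
    T v = q_.front();
    q_.pop();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> q_;
};

int main() {
  Channel<int> data;
  std::atomic<bool> quit{false};

  // Producer: plays the role of the while/select block from AddFibonacciSelect.
  std::thread producer([&] {
    int fibX = 0, fibY = 1;
    while (!quit) {
      data.Send(fibX);       // select case 0: channel_send of fibXToSend
      int xtemp = fibX;      // assign fibX -> xtemp
      fibX = fibY;           // assign fibY -> fibX
      fibY = xtemp + fibY;   // elementwise_add: fibY = xtemp + fibY
    }                        // select case 1 flips whileExitCond instead
  });

  // Consumer: plays the role of the go_op block that receives and prints
  // ten values.
  int currentXFib = 0;
  for (int i = 0; i < 10; ++i) {
    currentXFib = data.Recv();                       // channel_recv
    std::cout << "X: " << currentXFib << std::endl;  // print op
  }
  quit = true;               // stands in for the send on the quit channel
  producer.join();

  std::cout << currentXFib << std::endl;  // 34, the value EXPECT_EQ checks
  return 0;
}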