Skip to content

Commit 26a9208

Browse files
Varun Aroraabhinavarora
authored andcommitted
New PingPong test for testing channels / concurrency (#9132)
* New test for testing channels / concurrency * Formatting fix
1 parent e382e42 commit 26a9208

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

python/paddle/fluid/tests/test_concurrency.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,57 @@ def fibonacci(channel, quit_channel):
217217
exe_result = exe.run(fetch_list=[result])
218218
self.assertEqual(exe_result[0][0], 34)
219219

220+
def test_ping_pong(self):
221+
"""
222+
Mimics Ping Pong example: https://gobyexample.com/channel-directions
223+
"""
224+
with framework.program_guard(framework.Program()):
225+
result = self._create_tensor('return_value',
226+
core.VarDesc.VarType.LOD_TENSOR,
227+
core.VarDesc.VarType.FP64)
228+
229+
ping_result = self._create_tensor('ping_return_value',
230+
core.VarDesc.VarType.LOD_TENSOR,
231+
core.VarDesc.VarType.FP64)
232+
233+
pong_result = self._create_tensor('pong_return_value',
234+
core.VarDesc.VarType.LOD_TENSOR,
235+
core.VarDesc.VarType.FP64)
236+
237+
def ping(ch, message):
238+
message_to_send_tmp = fill_constant(
239+
shape=[1], dtype=core.VarDesc.VarType.FP64, value=0)
240+
241+
assign(input=message, output=message_to_send_tmp)
242+
fluid.channel_send(ch, message_to_send_tmp)
243+
244+
def pong(ch1, ch2):
245+
fluid.channel_recv(ch1, ping_result)
246+
assign(input=ping_result, output=pong_result)
247+
fluid.channel_send(ch2, pong_result)
248+
249+
pings = fluid.make_channel(
250+
dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
251+
pongs = fluid.make_channel(
252+
dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
253+
254+
msg = fill_constant(
255+
shape=[1], dtype=core.VarDesc.VarType.FP64, value=9)
256+
257+
ping(pings, msg)
258+
pong(pings, pongs)
259+
260+
fluid.channel_recv(pongs, result)
261+
262+
fluid.channel_close(pings)
263+
fluid.channel_close(pongs)
264+
265+
cpu = core.CPUPlace()
266+
exe = Executor(cpu)
267+
268+
exe_result = exe.run(fetch_list=[result])
269+
self.assertEqual(exe_result[0][0], 9)
270+
220271

221272
if __name__ == '__main__':
222273
unittest.main()

0 commit comments

Comments
 (0)