Skip to content

Commit c564af9

Browse files
author
zenghsh3
authored
fix unittest (#693)
* fix unittest * fix unittest * fix unittest
1 parent 48d7c78 commit c564af9

File tree

2 files changed

+98
-32
lines changed

2 files changed

+98
-32
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import parl
17+
from parl.remote.master import Master
18+
from parl.remote.worker import Worker
19+
import time
20+
import threading
21+
from parl.remote.client import disconnect
22+
from parl.remote import exceptions
23+
import subprocess
24+
from parl.utils import logger
25+
from parl.utils import get_free_tcp_port
26+
from unittest import mock
27+
28+
29+
@parl.remote_class
30+
class Actor(object):
31+
def __init__(self, arg1=None, arg2=None):
32+
self.arg1 = arg1
33+
self.arg2 = arg2
34+
35+
def get_arg1(self):
36+
return self.arg1
37+
38+
def get_arg2(self):
39+
return self.arg2
40+
41+
def set_arg1(self, value):
42+
self.arg1 = value
43+
44+
def set_arg2(self, value):
45+
self.arg2 = value
46+
47+
def add_one(self, value):
48+
value += 1
49+
return value
50+
51+
def add(self, x, y):
52+
time.sleep(3)
53+
return x + y
54+
55+
def will_raise_exception_func(self):
56+
x = 1 / 0
57+
58+
59+
class TestCluster(unittest.TestCase):
60+
def tearDown(self):
61+
disconnect()
62+
time.sleep(60) # wait for test case finishing
63+
64+
def test_actor_exception_2(self):
65+
return_true = mock.Mock(return_value=True)
66+
with mock.patch(
67+
'parl.remote.remote_class_serialization.is_implemented_in_notebook',
68+
return_true):
69+
port = get_free_tcp_port()
70+
logger.info("running: test_actor_exception_2")
71+
master = Master(port=port)
72+
th = threading.Thread(target=master.run)
73+
th.start()
74+
time.sleep(3)
75+
worker1 = Worker('localhost:{}'.format(port), 1)
76+
self.assertEqual(1, master.cpu_num)
77+
parl.connect('localhost:{}'.format(port))
78+
79+
actor = Actor()
80+
81+
with self.assertRaises(exceptions.RemoteError):
82+
actor.will_raise_exception_func()
83+
84+
actor2 = Actor()
85+
for _ in range(5):
86+
if master.cpu_num == 0:
87+
break
88+
time.sleep(10)
89+
self.assertEqual(actor2.add_one(1), 2)
90+
self.assertEqual(0, master.cpu_num)
91+
del actor
92+
del actor2
93+
worker1.exit()
94+
master.exit()
95+
96+
97+
if __name__ == '__main__':
98+
unittest.main()

parl/remote/tests/cluster_notebook_test.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -97,38 +97,6 @@ def test_actor_exception(self):
9797
master.exit()
9898
worker1.exit()
9999

100-
def test_actor_exception_2(self):
101-
return_true = mock.Mock(return_value=True)
102-
with mock.patch(
103-
'parl.remote.remote_class_serialization.is_implemented_in_notebook',
104-
return_true):
105-
port = get_free_tcp_port()
106-
logger.info("running: test_actor_exception_2")
107-
master = Master(port=port)
108-
th = threading.Thread(target=master.run)
109-
th.start()
110-
time.sleep(3)
111-
worker1 = Worker('localhost:{}'.format(port), 1)
112-
self.assertEqual(1, master.cpu_num)
113-
parl.connect('localhost:{}'.format(port))
114-
115-
actor = Actor()
116-
117-
with self.assertRaises(exceptions.RemoteError):
118-
actor.will_raise_exception_func()
119-
120-
actor2 = Actor()
121-
for _ in range(5):
122-
if master.cpu_num == 0:
123-
break
124-
time.sleep(10)
125-
self.assertEqual(actor2.add_one(1), 2)
126-
self.assertEqual(0, master.cpu_num)
127-
del actor
128-
del actor2
129-
worker1.exit()
130-
master.exit()
131-
132100

133101
if __name__ == '__main__':
134102
unittest.main()

0 commit comments

Comments
 (0)