Commit f9d8c84

Factor multiprocessing iterator code to base
1 parent acc0c60 commit f9d8c84

File tree

12 files changed: +398 -678 lines changed

.isort.cfg

Lines changed: 4 additions & 3 deletions

@@ -1,8 +1,9 @@
 [settings]
-force_grid_wrap=0
 force_single_line=True
 forced_separate=d1_common,d1_client,django,d1_gmn,d1_test
-include_trailing_comma=True
-line_length=88
+# Compatible with Black
 multi_line_output=3
+include_trailing_comma=True
+force_grid_wrap=0
 use_parentheses=True
+line_length=88
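Applied, the diff leaves .isort.cfg as follows. The reordered block under the "Compatible with Black" comment (multi_line_output=3, include_trailing_comma, force_grid_wrap=0, use_parentheses, line_length=88) matches the isort settings that Black's documentation recommends for compatibility:

[settings]
force_single_line=True
forced_separate=d1_common,d1_client,django,d1_gmn,d1_test
# Compatible with Black
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88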

.travis.yml

Lines changed: 1 addition & 5 deletions

@@ -5,7 +5,7 @@ dist:
 sudo:
   required
 python:
-  - 3.6
+  - "3.6"
 services:
   - postgresql
 addons:
@@ -48,10 +48,6 @@ script:
   - pytest --cov=. --cov-config .coveragerc --cov-report=term --cov-report=xml -n auto
   # --cov-config=tox.ini
   # --cov-config=coverage.cfg --cov-report=term --cov-report=xml
-  # -n auto
 after_success:
   # Submit results to Coveralls.io.
-  ## Coveralls has a requirement for requests >= 1.0.0, so we install it after
-  ## our packages to prevent it from pulling in the latest version, which is
-  ## likely to conflict with the fixed version our packages pull in.
   - coveralls
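A likely motivation for quoting the Python version: YAML coerces an unquoted 3.6 to a float, while "3.6" stays a string. A quick way to observe the difference (assumes PyYAML is installed; not part of this commit):

import yaml

# Unquoted numeric scalars parse as floats; quoted ones stay strings.
print(yaml.safe_load("python: [3.6]"))    # {'python': [3.6]}
print(yaml.safe_load('python: ["3.6"]'))  # {'python': ['3.6']}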

conftest.py

Lines changed: 9 additions & 0 deletions

@@ -428,6 +428,15 @@ def enable_db_access(db):
     pass


+@pytest.fixture(scope='function')
+def profile_sql(db):
+    django.db.connection.queries = []
+    yield
+    logging.info('SQL queries by all methods:')
+    list(map(logging.info, django.db.connection.queries))
+
+
 @pytest.yield_fixture(scope='session', autouse=True)
 @pytest.mark.django_db
 def django_db_setup(request, django_db_blocker):
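The new profile_sql fixture makes SQL profiling opt-in per test: it clears Django's query log before the test runs and logs every captured query afterwards. A minimal hypothetical usage sketch (the test name and body are illustrative, not from this commit):

def test_create_object(profile_sql):
    # Exercise code that issues SQL through the Django ORM here. When the
    # test body finishes, the fixture logs the contents of
    # django.db.connection.queries.
    ...

Note that Django only records queries in connection.queries when DEBUG is enabled, so the fixture logs nothing in non-debug runs.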

gmn/src/d1_gmn/tests/gmn_test_case.py

Lines changed: 0 additions & 12 deletions

@@ -63,18 +63,6 @@


 class GMNTestCase(d1_test.d1_test_case.D1TestCase):
-    def setup_class(self):
-        """Run for each test class that derives from GMNTestCase."""
-        if ENABLE_SQL_PROFILING:
-            django.db.connection.queries = []
-
-    def teardown_class(self):
-        """Run for each test class that derives from GMNTestCase."""
-        GMNTestCase.capture_exception()
-        if ENABLE_SQL_PROFILING:
-            logging.info('SQL queries by all methods:')
-            list(map(logging.info, django.db.connection.queries))
-
     def setup_method(self, method):
         """Run for each test method that derives from GMNTestCase."""
         # logging.error('GMNTestCase.setup_method()')
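This removal appears to replace the implicit, class-level SQL profiling gated on ENABLE_SQL_PROFILING with the explicit profile_sql fixture added in conftest.py above; tests that need query logging can now request the fixture directly.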
Lines changed: 211 additions & 0 deletions

@@ -0,0 +1,211 @@
+# This work was created by participants in the DataONE project, and is
+# jointly copyrighted by participating institutions in DataONE. For
+# more information on DataONE, see our web site at http://dataone.org.
+#
+#   Copyright 2009-2019 DataONE
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Base for multiprocessed DataONE type iterator."""
+
+import logging
+import multiprocessing
+import time
+
+import d1_common.types.exceptions
+
+import d1_client.mnclient_1_2
+import d1_client.mnclient_2_0
+
+# Defaults
+PAGE_SIZE = 1000
+MAX_WORKERS = 16
+# See notes in module docstring for SysMeta iterator before changing
+MAX_RESULT_QUEUE_SIZE = 100
+MAX_TASK_QUEUE_SIZE = 16
+API_MAJOR = 2
+
+logger = logging.getLogger(__name__)
+
+# fmt: off
+class MultiprocessedIteratorBase(object):
+    def __init__(
+        self,
+        base_url, page_size, max_workers, max_result_queue_size,
+        max_task_queue_size, api_major, client_arg_dict, page_arg_dict,
+        item_proc_arg_dict, page_func, iter_func, item_proc_func,
+    ):
+        self._base_url = base_url
+        self._page_size = page_size
+        self._max_workers = max_workers
+        self._max_result_queue_size = max_result_queue_size
+        self._max_task_queue_size = max_task_queue_size
+        self._api_major = api_major
+        self._client_arg_dict = client_arg_dict or {}
+        self._page_arg_dict = page_arg_dict or {}
+        self._item_proc_arg_dict = item_proc_arg_dict or {}
+        self._page_func = page_func
+        self._iter_func = iter_func
+        self._item_proc_func = item_proc_func
+        self._total = None
+
+    @property
+    def total(self):
+        if self._total is None:
+            client = create_client(
+                self._base_url, self._api_major, self._client_arg_dict
+            )
+            page_pyxb = self._page_func(client)(
+                start=0, count=0, **self._page_arg_dict
+            )
+            self._total = page_pyxb.total
+        return self._total
+
+    def __iter__(self):
+        manager = multiprocessing.Manager()
+        queue = manager.Queue(maxsize=self._max_result_queue_size)
+        namespace = manager.Namespace()
+        namespace.stop = False
+
+        process = multiprocessing.Process(
+            target=_get_all_pages,
+            args=(
+                queue, namespace, self._base_url, self._page_size, self._max_workers,
+                self._max_task_queue_size, self._api_major, self._client_arg_dict,
+                self._page_arg_dict, self._item_proc_arg_dict, self._page_func,
+                self._iter_func, self._item_proc_func, self.total
+            ),
+        )
+
+        process.start()
+
+        try:
+            while True:
+                item_result = queue.get()
+                if item_result is None:
+                    logger.debug("Received None sentinel value. Stopping iteration")
+                    break
+                elif isinstance(item_result, dict):
+                    logger.debug('Raising exception received as dict. dict="{}"'.format(item_result))
+                    yield d1_common.types.exceptions.create_exception_by_name(
+                        item_result["error"],
+                        identifier=item_result["pid"],
+                    )
+                else:
+                    yield item_result
+        except GeneratorExit:
+            logger.debug("GeneratorExit exception")
+
+        # If the generator is exited before it is exhausted, provide a clean
+        # shutdown by signaling the processes to stop, then waiting for them.
+        logger.debug("Setting stop signal")
+        namespace.stop = True
+        # Prevent the parent from leaving zombie children behind.
+        while queue.qsize():
+            logger.debug("Dropping unwanted result")
+            queue.get()
+        logger.debug("Waiting for process to exit")
+        process.join()
+
+
+def _get_all_pages(
+    queue, namespace, base_url, page_size, max_workers, max_task_queue_size, api_major,
+    client_arg_dict, page_arg_dict, item_proc_arg_dict, page_func, iter_func, item_proc_func, n_total
+):
+    logger.debug("Creating pool of {} workers".format(max_workers))
+    pool = multiprocessing.Pool(processes=max_workers)
+    n_pages = (n_total - 1) // page_size + 1
+
+    for page_idx in range(n_pages):
+        if namespace.stop:
+            logger.debug("Received stop signal")
+            break
+        try:
+            pool.apply_async(
+                _get_page,
+                args=(
+                    queue, namespace, base_url, page_idx, n_pages, page_size, api_major,
+                    client_arg_dict, page_arg_dict, item_proc_arg_dict, page_func,
+                    iter_func, item_proc_func
+                ),
+            )
+        except Exception as e:
+            logger.debug('Continuing after exception. error="{}"'.format(str(e)))
+        # The pool does not support a clean way to limit the number of queued
+        # tasks, so we have to access the internals to check the queue size
+        # and wait if necessary.
+        while pool._taskqueue.qsize() > max_task_queue_size:
+            if namespace.stop:
+                logger.debug("Received stop signal")
+                break
+            # logger.debug('_get_all_pages(): Waiting to queue task')
+            time.sleep(1)
+
+    # Workaround for workers hanging at exit.
+    # pool.terminate()
+    logger.debug("Preventing more tasks from being added to the pool")
+    pool.close()
+    logger.debug("Waiting for the workers to exit")
+    pool.join()
+    logger.debug("Sending None sentinel value to stop the generator")
+    queue.put(None)
+
+
+def _get_page(
+    queue, namespace, base_url, page_idx, n_pages, page_size, api_major,
+    client_arg_dict, page_arg_dict, item_proc_arg_dict, page_func, iter_func, item_proc_func
+):
+    logger.debug("Processing page. page_idx={} n_pages={}".format(page_idx, n_pages))
+
+    if namespace.stop:
+        logger.debug("Received stop signal")
+        return
+
+    client = create_client(base_url, api_major, client_arg_dict)
+
+    try:
+        page_pyxb = page_func(client)(
+            start=page_idx * page_size, count=page_size, **page_arg_dict
+        )
+    except Exception as e:
+        logger.error(
+            'Unable to get page. page_idx={} page_total={} error="{}"'.format(
+                page_idx, n_pages, str(e)
+            )
+        )
+        return
+
+    iterable_pyxb = iter_func(page_pyxb)
+
+    logger.debug(
+        "Starting page item iteration. page_idx={} n_items={}".format(
+            page_idx, len(iterable_pyxb)
+        )
+    )
+
+    for item_pyxb in iterable_pyxb:
+        if namespace.stop:
+            logger.debug("Received stop signal")
+            break
+        queue.put(item_proc_func(client, item_pyxb, item_proc_arg_dict))
+
+    logger.debug("Completed page")
+
+
+def create_client(base_url, api_major, client_arg_dict):
+    if api_major in (1, "1", "v1"):
+        return d1_client.mnclient_1_2.MemberNodeClient_1_2(base_url, **client_arg_dict)
+    else:
+        return d1_client.mnclient_2_0.MemberNodeClient_2_0(base_url, **client_arg_dict)
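The three hook functions define the full contract of the base class: page_func returns the bound client method that fetches one page, iter_func extracts the item list from a page, and item_proc_func turns one PyXB item into the value the iterator yields. Because the hooks are passed as arguments to multiprocessing.Process and Pool.apply_async, they generally must be picklable, i.e. module-level functions rather than lambdas or closures. Below is a minimal sketch of a derived iterator; ObjectListIteratorExample and the listObjects/objectInfo wiring are illustrative assumptions based on the hooks above, not code from this commit:

# Illustrative only: a concrete object-list iterator built on the base class.
def _page_func(client):
    # Bound client method that retrieves one page of results.
    return client.listObjects

def _iter_func(page_pyxb):
    # The iterable of items contained in one ObjectList page.
    return page_pyxb.objectInfo

def _item_proc_func(client, item_pyxb, item_proc_arg_dict):
    # Per-item processing; here the PyXB item is passed through unchanged.
    return item_pyxb

class ObjectListIteratorExample(MultiprocessedIteratorBase):
    def __init__(self, base_url, api_major=API_MAJOR, client_arg_dict=None):
        super().__init__(
            base_url, PAGE_SIZE, MAX_WORKERS, MAX_RESULT_QUEUE_SIZE,
            MAX_TASK_QUEUE_SIZE, api_major, client_arg_dict,
            page_arg_dict=None, item_proc_arg_dict=None,
            page_func=_page_func, iter_func=_iter_func,
            item_proc_func=_item_proc_func,
        )

# Usage: yields object_info_pyxb items across all pages, fetched in parallel.
# for object_info_pyxb in ObjectListIteratorExample("https://example.org/mn"):
#     print(object_info_pyxb.identifier.value())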

0 commit comments