Skip to content

Commit 2935865

Browse files
authored
Make worker.get_full_worker_args more user friendly (#223)
We can just accept `**kwargs` - this allows usage like `worker.get_full_worker_args(some_type, arg1=val1, arg2=val2)`
1 parent f89c355 commit 2935865

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def create_local_worker_pool(worker_cls: 'type[worker.Worker]',
244244
"""Create a local worker pool for worker_cls."""
245245
if not count:
246246
count = multiprocessing.get_context().cpu_count()
247-
final_kwargs = worker.get_full_worker_args(worker_cls, kwargs)
247+
final_kwargs = worker.get_full_worker_args(worker_cls, **kwargs)
248248
stubs = [_make_stub(worker_cls, *args, **final_kwargs) for _ in range(count)]
249249
return worker.FixedWorkerPool(workers=stubs, worker_concurrency=16)
250250

compiler_opt/distributed/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def get_exception(worker_future: WorkerFuture) -> Optional[Exception]:
9191
return e
9292

9393

94-
def get_full_worker_args(worker_class: 'type[Worker]', current_kwargs):
94+
def get_full_worker_args(worker_class: 'type[Worker]', **current_kwargs):
9595
"""Get the union of given kwargs and gin config.
9696
9797
This allows the worker hosting process be set up differently from the training
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Test for worker."""
16+
17+
import gin
18+
19+
from absl.testing import absltest
20+
from compiler_opt.distributed import worker
21+
22+
23+
@gin.configurable(module='_test')
24+
class SomeType:
25+
26+
def __init__(self, argument):
27+
pass
28+
29+
30+
class WorkerTest(absltest.TestCase):
31+
32+
def test_gin_args(self):
33+
with gin.unlock_config():
34+
gin.bind_parameter('_test.SomeType.argument', 42)
35+
real_args = worker.get_full_worker_args(
36+
SomeType, more_args=2, even_more_args='hi')
37+
self.assertDictEqual(real_args,
38+
dict(argument=42, more_args=2, even_more_args='hi'))
39+
40+
41+
if __name__ == '__main__':
42+
absltest.main()

0 commit comments

Comments
 (0)