Skip to content

Commit 9ed2fba

Browse files
committed
Merge pull request #585 from enthought/bugfix/apply-extract-da-key
Make Context.apply() work automatically with DAs.
2 parents ccaef6d + 0889a45 commit 9ed2fba

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

distarray/globalapi/context.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def make_subcomm(self, new_targets):
7171
pass
7272

7373
@abstractmethod
74-
def apply(self, func, args=None, kwargs=None, targets=None):
74+
def apply(self, func, args=None, kwargs=None, targets=None, autoproxyize=False):
7575
pass
7676

7777
@abstractmethod
@@ -816,6 +816,10 @@ def func_wrapper(func, apply_nonce, context_key, args, kwargs, autoproxyize):
816816
# default arguments
817817
args = () if args is None else args
818818
kwargs = {} if kwargs is None else kwargs
819+
820+
args = tuple(a.key if isinstance(a, DistArray) else a for a in args)
821+
kwargs = {k: (v.key if isinstance(v, DistArray) else v) for k, v in kwargs.items()}
822+
819823
apply_nonce = nonce()
820824
wrapped_args = (func, apply_nonce, self.context_key, args, kwargs, autoproxyize)
821825

@@ -972,6 +976,10 @@ def apply(self, func, args=None, kwargs=None, targets=None, autoproxyize=False):
972976
# default arguments
973977
args = () if args is None else args
974978
kwargs = {} if kwargs is None else kwargs
979+
980+
args = tuple(a.key if isinstance(a, DistArray) else a for a in args)
981+
kwargs = {k: (v.key if isinstance(v, DistArray) else v) for k, v in kwargs.items()}
982+
975983
targets = self.targets if targets is None else targets
976984

977985
apply_nonce = nonce()

distarray/globalapi/tests/test_context.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,6 @@ def foo(a, b, c=None, d=None):
378378

379379
self.assertEqual(val, [9] * self.ntargets)
380380

381-
382381
def test_apply_proxy(self):
383382

384383
def foo():
@@ -401,6 +400,20 @@ def foo():
401400
self.assertEqual(set(r[0].name for r in res), set([res[0][0].name]))
402401
self.assertEqual(set(r[-1].name for r in res), set([res[0][-1].name]))
403402

403+
def test_apply_distarray(self):
404+
405+
da = self.context.empty((len(self.context.targets),), dtype=numpy.uint32)
406+
407+
def local_label(la):
408+
la.ndarray.fill(la.comm.rank)
409+
410+
# Testing that we can pass in `da` and `apply()` extracts `da.key` automatically.
411+
self.context.apply(local_label, (da,))
412+
assert_array_equal(da.tondarray(), range(len(self.context.targets)))
413+
414+
self.context.apply(local_label, kwargs={'la': da})
415+
assert_array_equal(da.tondarray(), range(len(self.context.targets)))
416+
404417
class TestGetBaseComm(DefaultContextTestCase):
405418

406419
ntargets = 'any'

0 commit comments

Comments
 (0)