Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions distarray/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
Version information for the DistArray package.
"""

__short_version__ = "0.5"
__version__ = "0.5.0"
__short_version__ = "0.6"
__version__ = "0.6.0-dev"
10 changes: 8 additions & 2 deletions distarray/globalapi/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ def apply(self, func, args=None, kwargs=None, targets=None):
def push_function(self, key, func):
pass

def __enter__(self):
return self

def __exit__(self, type_, value, traceback):
self.close()

def _setup_context_key(self):
"""
Create a dict on the engines which will hold everything from
Expand Down Expand Up @@ -205,7 +211,7 @@ def local_allclose(la, lb, rtol, atol):
from numpy import allclose
return allclose(la.ndarray, lb.ndarray, rtol, atol)

local_results = self.apply(local_allclose,
local_results = self.apply(local_allclose,
(a.key, b.key, rtol, atol),
targets=a.targets)
return all(local_results)
Expand Down Expand Up @@ -579,7 +585,7 @@ def is_NoneType(pxy):
return pxy.type_str == str(type(None))

def is_LocalArray(pxy):
return (isinstance(pxy, Proxy) and
return (isinstance(pxy, Proxy) and
pxy.type_str == "<class 'distarray.localapi.localarray.LocalArray'>")

if all(is_LocalArray(r) for r in results):
Expand Down
19 changes: 15 additions & 4 deletions distarray/globalapi/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

Many of these tests require a 4-engine cluster to be running locally. The
engines should be launched with MPI, using the MPIEngineSetLauncher.

"""

import unittest
Expand All @@ -19,13 +18,25 @@

from numpy.testing import assert_allclose, assert_array_equal

from distarray.testing import DefaultContextTestCase, IPythonContextTestCase, check_targets
from distarray.testing import (DefaultContextTestCase, IPythonContextTestCase,
check_targets)
from distarray.globalapi.context import Context
from distarray.globalapi.maps import Distribution
from distarray.mpionly_utils import is_solo_mpi_process, get_nengines
from distarray.localapi import LocalArray


class TestContextManager(DefaultContextTestCase):

ntargets = 'any'

def test_manager(self):
with Context() as mycon:
testarr = mycon.zeros((10,10))
# `close` is currently a no-op for MPI contexts, so I don't test
# anything regarding the __exit__ behavior


class TestRegister(DefaultContextTestCase):

ntargets = 'any'
Expand Down Expand Up @@ -53,7 +64,7 @@ def test_local_sin(self):
def local_sin(da):
return numpy.sin(da)
self.context.register(local_sin)

db = self.context.local_sin(self.da)
assert_allclose(0, db.tondarray(), atol=1e-14)

Expand Down Expand Up @@ -146,7 +157,7 @@ def local_none(da):
self.assertTrue(dp is None)

def test_parameterless(self):

def parameterless():
"""This is a parameterless function."""
return None
Expand Down