Skip to content

Commit 5b0a7b7

Browse files
Minor cleanup in datatypes (#438)
* Made communicator class attribute of `mesh` * Moved to communicator as class attribute of `Tensor` as well * Forgot to add file to commit. * Fixes * Fixed environment
1 parent 86dfb23 commit 5b0a7b7

File tree

4 files changed

+132
-75
lines changed

4 files changed

+132
-75
lines changed

etc/environment-pytorch.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ dependencies:
1111
- pytorch
1212
- matplotlib>=3.0
1313
- dill
14+
- mpich
15+
- mpi4py>=3.0.0

pySDC/implementations/datatype_classes/mesh.py

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import numpy as np
22

3-
from pySDC.core.Errors import DataError
4-
53
try:
64
# TODO : mpi4py cannot be imported before dolfin when using fenics mesh
75
# see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590
@@ -17,10 +15,12 @@ class mesh(np.ndarray):
1715
Can include a communicator and expects a dtype to allow complex data.
1816
1917
Attributes:
20-
_comm: MPI communicator or None
18+
comm: MPI communicator or None
2119
"""
2220

23-
def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None):
21+
comm = None
22+
23+
def __new__(cls, init, val=0.0, **kwargs):
2424
"""
2525
Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.
2626
@@ -33,56 +33,32 @@ def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None)
3333
3434
"""
3535
if isinstance(init, mesh):
36-
obj = np.ndarray.__new__(
37-
cls, shape=init.shape, dtype=init.dtype, buffer=buffer, offset=offset, strides=strides, order=order
38-
)
36+
obj = np.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
3937
obj[:] = init[:]
40-
obj._comm = init._comm
4138
elif (
4239
isinstance(init, tuple)
4340
and (init[1] is None or isinstance(init[1], MPI.Intracomm))
4441
and isinstance(init[2], np.dtype)
4542
):
46-
obj = np.ndarray.__new__(
47-
cls, init[0], dtype=init[2], buffer=buffer, offset=offset, strides=strides, order=order
48-
)
43+
obj = np.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
4944
obj.fill(val)
50-
obj._comm = init[1]
45+
cls.comm = init[1]
5146
else:
5247
raise NotImplementedError(type(init))
5348
return obj
5449

55-
@property
56-
def comm(self):
57-
"""
58-
Getter for the communicator
59-
"""
60-
return self._comm
61-
62-
def __array_finalize__(self, obj):
63-
"""
64-
Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator.
65-
"""
66-
if obj is None:
67-
return
68-
self._comm = getattr(obj, '_comm', None)
69-
7050
def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
7151
"""
7252
Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs
7353
"""
7454
args = []
75-
comm = None
7655
for _, input_ in enumerate(inputs):
7756
if isinstance(input_, mesh):
7857
args.append(input_.view(np.ndarray))
79-
comm = input_.comm
8058
else:
8159
args.append(input_)
8260

83-
results = super(mesh, self).__array_ufunc__(ufunc, method, *args, **kwargs).view(type(self))
84-
if type(self) == type(results):
85-
results._comm = comm
61+
results = super().__array_ufunc__(ufunc, method, *args, **kwargs).view(type(self))
8662
return results
8763

8864
def __abs__(self):

pySDC/playgrounds/ML_initial_guess/tensor.py

Lines changed: 15 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import numpy as np
21
import torch
32

4-
from pySDC.core.Errors import DataError
5-
63
try:
74
from mpi4py import MPI
85
except ImportError:
@@ -12,14 +9,17 @@
129
class Tensor(torch.Tensor):
1310
"""
1411
Wrapper for PyTorch tensor.
15-
Be aware that this is totally WIP! Should be fine to count iterations, but desperately needs cleaning up if this project goes much further!
12+
Be aware that this is totally WIP! Should be fine to count iterations, but desperately needs cleaning up if this
13+
project goes much further!
1614
1715
TODO: Have to update `torch/multiprocessing/reductions.py` in order to share this datatype across processes.
1816
1917
Attributes:
20-
_comm: MPI communicator or None
18+
comm: MPI communicator or None
2119
"""
2220

21+
comm = None
22+
2323
@staticmethod
2424
def __new__(cls, init, val=0.0, *args, **kwargs):
2525
"""
@@ -33,54 +33,26 @@ def __new__(cls, init, val=0.0, *args, **kwargs):
3333
obj of type mesh
3434
3535
"""
36-
if isinstance(init, Tensor):
37-
obj = super().__new__(cls, init)
36+
# TODO: The cloning of tensors going in is likely slow
37+
38+
if isinstance(init, torch.Tensor):
39+
obj = super().__new__(cls, init.clone())
3840
obj[:] = init[:]
39-
obj._comm = init._comm
4041
elif (
4142
isinstance(init, tuple)
42-
# and (init[1] is None or isinstance(init[1], MPI.Intracomm))
43+
and (init[1] is None or isinstance(init[1], MPI.Intracomm))
4344
# and isinstance(init[2], np.dtype)
4445
):
45-
obj = super().__new__(cls, init[0].clone())
46+
if isinstance(init[0][0], torch.Tensor):
47+
obj = super().__new__(cls, init[0].clone())
48+
else:
49+
obj = super().__new__(cls, *init[0])
4650
obj.fill_(val)
47-
obj._comm = init[1]
51+
cls.comm = init[1]
4852
else:
4953
raise NotImplementedError(type(init))
5054
return obj
5155

52-
def __add__(self, *args, **kwargs):
53-
res = super().__add__(*args, **kwargs)
54-
res._comm = self.comm
55-
return res
56-
57-
def __sub__(self, *args, **kwargs):
58-
res = super().__sub__(*args, **kwargs)
59-
res._comm = self.comm
60-
return res
61-
62-
def __lmul__(self, *args, **kwargs):
63-
res = super().__lmul__(*args, **kwargs)
64-
res._comm = self.comm
65-
return res
66-
67-
def __rmul__(self, *args, **kwargs):
68-
res = super().__rmul__(*args, **kwargs)
69-
res._comm = self.comm
70-
return res
71-
72-
def __mul__(self, *args, **kwargs):
73-
res = super().__mul__(*args, **kwargs)
74-
res._comm = self.comm
75-
return res
76-
77-
@property
78-
def comm(self):
79-
"""
80-
Getter for the communicator
81-
"""
82-
return self._comm
83-
8456
def __abs__(self):
8557
"""
8658
Overloading the abs operator
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import pytest
2+
3+
4+
def get_dtype(name):
    """
    Look up a datatype class by its name.

    The imports happen inside the branches, so only the module providing the requested datatype is imported.

    Args:
        name (str): One of 'Tensor', 'mesh' or 'imex_mesh'

    Returns:
        type: The datatype class with that name

    Raises:
        NotImplementedError: If ``name`` is none of the names listed above
    """
    if name == 'Tensor':
        from pySDC.playgrounds.ML_initial_guess.tensor import Tensor as dtype_cls
    elif name in ['mesh', 'imex_mesh']:
        import pySDC.implementations.datatype_classes.mesh as mesh

        # Plain attribute lookup; no need to `eval` a formatted string
        dtype_cls = getattr(mesh, name)
    else:
        raise NotImplementedError(f'Don\'t know a dtype of name {name!r}!')

    return dtype_cls
15+
16+
17+
def single_test(name, useMPI=False):
    """
    Verify that newly generated instances of a datatype keep their type and their communicator.

    Instances are created from an init tuple, by copying an existing instance and by adding two instances.
    With MPI, the communicator is obtained by splitting the world communicator, so a communicator different from
    ``COMM_WORLD`` is supplied to the datatype.

    Args:
        name (str): Name of the datatype, resolved via ``get_dtype``
        useMPI (bool): Whether to supply an MPI communicator instead of ``None``
    """
    import numpy as np

    dtype_cls = get_dtype(name)

    shape = (5,)
    dtype = np.dtype('f')

    if useMPI:
        from mpi4py import MPI

        comm_wd = MPI.COMM_WORLD
        # The split colour is whether this rank is below the last world rank
        comm = comm_wd.Split(comm_wd.rank < comm_wd.size - 1)

        expected_rank = comm_wd.rank % (comm_wd.size - 1)
    else:
        comm = None

    init = (shape, comm, dtype)

    first = dtype_cls(init, val=1.0)
    second = dtype_cls(init, val=99.0)
    copied = dtype_cls(first)
    summed = first + second

    for me in (first, second, copied, summed):
        assert type(me) == dtype_cls
        assert me.comm == comm

        # Only compare shapes for datatypes that expose `shape` but not `components`
        if hasattr(me, 'shape') and not hasattr(me, 'components'):
            assert me.shape == shape

    if useMPI:
        assert comm.rank == expected_rank
        assert comm.size < comm_wd.size
55+
56+
57+
def launch_test(name, useMPI, num_procs=1):
    """
    Run ``single_test`` either directly in this process or as a script under ``mpirun``.

    Args:
        name (str): Name of the datatype to test
        useMPI (bool): If True, execute this file with ``mpirun`` and check its return code
        num_procs (int): Number of MPI processes to launch (only used with MPI)
    """
    if not useMPI:
        single_test(name, False)
        return

    import os
    import subprocess

    # Set python path once
    my_env = os.environ.copy()
    my_env['PYTHONPATH'] = '../../..:.'
    my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml'

    # Build the argument list directly: splitting a formatted command string on whitespace would tear apart a
    # file path that contains spaces
    cmd = ['mpirun', '-np', str(num_procs), 'python', __file__, f'--name={name}', '--useMPI=True']

    # `run` waits for the process to finish and cleans up after it
    p = subprocess.run(cmd, env=my_env, cwd=".")

    assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (
        p.returncode,
        num_procs,
    )
78+
79+
80+
@pytest.mark.pytorch
@pytest.mark.parametrize('useMPI', [True, False])
def test_PyTorch_dtype(useMPI):
    """
    Check the PyTorch ``Tensor`` datatype, once with and once without MPI.
    """
    launch_test(name='Tensor', useMPI=useMPI, num_procs=4)
84+
85+
86+
@pytest.mark.mpi4py
@pytest.mark.parametrize('name', ['mesh', 'imex_mesh'])
def test_mesh_dtypes_MPI(name):
    """
    Check the mesh datatypes with MPI on four processes.
    """
    launch_test(name=name, useMPI=True, num_procs=4)
90+
91+
92+
@pytest.mark.base
@pytest.mark.parametrize('name', ['mesh', 'imex_mesh'])
def test_mesh_dtypes(name):
    """
    Check the mesh datatypes without MPI.
    """
    launch_test(name=name, useMPI=False)
96+
97+
98+
if __name__ == '__main__':
    # Script entry point: `launch_test` executes this file under `mpirun` with `--name` and `--useMPI`
    import argparse

    def str_to_bool(me):
        """Interpret every string other than 'False' as True."""
        return me != 'False'

    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, help='Name of the datatype')
    parser.add_argument('--useMPI', type=str_to_bool, help='Toggle for MPI', choices=[True, False])
    args = parser.parse_args()

    single_test(name=args.name, useMPI=args.useMPI)

0 commit comments

Comments
 (0)