Skip to content

Commit 51630f8

Browse files
committed
Allow passing numpy types to operators
1 parent 442361b commit 51630f8

File tree

8 files changed

+64
-15
lines changed

8 files changed

+64
-15
lines changed

examples/plot_convolve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
plt.close('all')
2323

2424
###############################################################################
25-
# We will start by creating a zero signal of lenght :math:`nt` and we will
25+
# We will start by creating a zero signal of length :math:`nt` and we will
2626
# place a unitary spike at its center. We also create our filter to be
2727
# applied by means of :py:class:`pylops_gpu.signalprocessing.Convolve1D`
2828
# operator.

pylops_gpu/basicoperators/Diagonal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pylops_gpu import LinearOperator
66
from pylops_gpu.utils.complex import conj, flatten, reshape, \
77
complextorch_fromnumpy
8+
from pylops_gpu.utils.torch2numpy import torchtype_from_numpytype
89

910

1011
class Diagonal(LinearOperator):
@@ -78,7 +79,7 @@ def __init__(self, diag, dims=None, dir=0, device='cpu',
7879
self.device = device
7980
self.togpu = togpu
8081
self.tocpu = tocpu
81-
self.dtype = dtype
82+
self.dtype = torchtype_from_numpytype(dtype)
8283
self.explicit = False
8384
self.Op = None
8485

pylops_gpu/basicoperators/FirstDerivative.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pylops_gpu import LinearOperator
44
from pylops_gpu.signalprocessing import Convolve1D
5+
from pylops_gpu.utils.torch2numpy import torchtype_from_numpytype
56

67

78
"""
@@ -39,7 +40,7 @@ class FirstDerivative(LinearOperator):
3940
tocpu : :obj:`tuple`, optional
4041
Move data and model from gpu to cpu after applying ``matvec`` and
4142
``rmatvec``, respectively (only when ``device='gpu'``)
42-
dtype : :obj:`torch.dtype`, optional
43+
dtype : :obj:`torch.dtype` or :obj:`np.dtype`, optional
4344
Type of elements in input array.
4445
4546
Attributes
@@ -63,6 +64,9 @@ class FirstDerivative(LinearOperator):
6364
def __init__(self, N, dims=None, dir=0, sampling=1., device='cpu',
6465
togpu=(False, False), tocpu=(False, False),
6566
dtype=torch.float32):
67+
# convert dtype to torch.dtype
68+
dtype = torchtype_from_numpytype(dtype)
69+
6670
h = torch.torch.tensor([0.5, 0, -0.5],
6771
dtype=dtype).to(device) / sampling
6872
self.device = device
@@ -73,4 +77,4 @@ def __init__(self, N, dims=None, dir=0, sampling=1., device='cpu',
7377
self.explicit = False
7478
self.Op = Convolve1D(N, h, offset=1, dims=dims, dir=dir,
7579
zero_edges=True, device=device,
76-
togpu=togpu, tocpu=tocpu, dtype=dtype)
80+
togpu=togpu, tocpu=tocpu, dtype=self.dtype)

pylops_gpu/basicoperators/MatrixMult.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from pytorch_complex_tensor import ComplexTensor
55
from pylops_gpu.LinearOperator import LinearOperator
66
from pylops_gpu.utils.complex import conj, reshape, flatten
7-
from pylops_gpu.utils.torch2numpy import numpytype_from_torchtype
7+
from pylops_gpu.utils.torch2numpy import numpytype_from_torchtype, \
8+
torchtype_from_numpytype
89

910

1011
class MatrixMult(LinearOperator):
@@ -29,7 +30,7 @@ class MatrixMult(LinearOperator):
2930
tocpu : :obj:`tuple`, optional
3031
Move data and model from gpu to cpu after applying ``matvec`` and
3132
``rmatvec``, respectively (only when ``device='gpu'``)
32-
dtype : :obj:`torch.dtype`, optional
33+
dtype : :obj:`torch.dtype` or :obj:`np.dtype`, optional
3334
Type of elements in input array.
3435
3536
Attributes
@@ -49,10 +50,12 @@ class MatrixMult(LinearOperator):
4950
def __init__(self, A, dims=None, device='cpu',
5051
togpu=(False, False), tocpu=(False, False),
5152
dtype=torch.float32):
53+
# convert A to torch tensor if provided as numpy array numpy
5254
if not isinstance(A, (torch.Tensor, ComplexTensor)):
53-
self.complex = True if np.iscomplexobj(A) else False
55+
dtype = numpytype_from_torchtype(dtype)
5456
self.A = \
5557
torch.from_numpy(A.astype(numpytype_from_torchtype(dtype))).to(device)
58+
self.complex = True if np.iscomplexobj(A) else False
5659
else:
5760
self.complex = True if isinstance(A, ComplexTensor) else False
5861
self.A = A
@@ -76,7 +79,7 @@ def __init__(self, A, dims=None, device='cpu',
7679
self.device = device
7780
self.togpu = togpu
7881
self.tocpu = tocpu
79-
self.dtype = dtype
82+
self.dtype = torchtype_from_numpytype(dtype)
8083
self.explicit = True
8184
self.Op = None
8285

pylops_gpu/basicoperators/Restriction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33

44
from pylops_gpu.LinearOperator import LinearOperator
5+
from pylops_gpu.utils.torch2numpy import torchtype_from_numpytype
56

67

78
class Restriction(LinearOperator):
@@ -78,7 +79,7 @@ def __init__(self, M, iava, dims=None, dir=0, inplace=True,
7879
self.device = device
7980
self.togpu = togpu
8081
self.tocpu = tocpu
81-
self.dtype = dtype
82+
self.dtype = torchtype_from_numpytype(dtype)
8283
self.explicit = True
8384
self.Op = None
8485

pylops_gpu/basicoperators/SecondDerivative.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pylops_gpu import LinearOperator
44
from pylops_gpu.signalprocessing import Convolve1D
5+
from pylops_gpu.utils.torch2numpy import torchtype_from_numpytype
56

67

78
class SecondDerivative(LinearOperator):
@@ -28,7 +29,7 @@ class SecondDerivative(LinearOperator):
2829
tocpu : :obj:`tuple`, optional
2930
Move data and model from gpu to cpu after applying ``matvec`` and
3031
``rmatvec``, respectively (only when ``device='gpu'``)
31-
dtype : :obj:`torch.dtype`, optional
32+
dtype : :obj:`torch.dtype` or :obj:`np.dtype`, optional
3233
Type of elements in input array.
3334
3435
Attributes
@@ -52,6 +53,9 @@ class SecondDerivative(LinearOperator):
5253
def __init__(self, N, dims=None, dir=0, sampling=1., device='cpu',
5354
togpu=(False, False), tocpu=(False, False),
5455
dtype=torch.float32):
56+
# convert dtype to torch.dtype
57+
dtype = torchtype_from_numpytype(dtype)
58+
5559
h = torch.torch.tensor([1., -2, 1.],
5660
dtype=dtype).to(device) / sampling**2
5761
self.device = device

pylops_gpu/signalprocessing/Convolve1D.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from torch.nn.functional import pad
55
from pylops_gpu import LinearOperator
6+
from pylops_gpu.utils.torch2numpy import torchtype_from_numpytype
67

78

89
class Convolve1D(LinearOperator):
@@ -16,7 +17,7 @@ class Convolve1D(LinearOperator):
1617
----------
1718
N : :obj:`int`
1819
Number of samples in model.
19-
h : :obj:`torch.Tensor`
20+
h : :obj:`torch.Tensor` or :obj:`numpy.ndarray`
2021
1d compact filter to be convolved to input signal
2122
offset : :obj:`int`
2223
Index of the center of the compact filter
@@ -55,6 +56,13 @@ class Convolve1D(LinearOperator):
5556
def __init__(self, N, h, offset=0, dims=None, dir=0, zero_edges=False,
5657
device='cpu', togpu=(False, False), tocpu=(False, False),
5758
dtype=torch.float32):
59+
# convert dtype to torch.dtype
60+
if not isinstance(dtype, torch.dtype):
61+
dtype = torchtype_from_numpytype(dtype)
62+
63+
# convert h to torch if numpy
64+
if not isinstance(h, torch.Tensor):
65+
h = torch.from_numpy(h).to(device)
5866
self.nh = h.size()[0]
5967
self.h = h.reshape(1, 1, self.nh)
6068
self.offset = 2*(self.nh // 2 - int(offset))

pylops_gpu/utils/torch2numpy.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,25 @@
22
import numpy as np
33

44

5+
6+
def numpytype_from_strtype(strtype):
7+
"""Convert str into equivalent numpy type
8+
9+
Parameters
10+
----------
11+
strtype : :obj:`str`
12+
String type
13+
14+
Returns
15+
-------
16+
numpytype : :obj:`numpy.dtype`
17+
Numpy equivalent type
18+
19+
"""
20+
numpytype = np.dtype(strtype)
21+
return numpytype
22+
23+
524
def numpytype_from_torchtype(torchtype):
625
"""Convert torch type into equivalent numpy type
726
@@ -12,11 +31,15 @@ def numpytype_from_torchtype(torchtype):
1231
1332
Returns
1433
-------
15-
numpytype : :obj:`torch.dtype`
34+
numpytype : :obj:`numpy.dtype`
1635
Numpy equivalent type
1736
1837
"""
19-
numpytype = torch.scalar_tensor(1, dtype=torchtype).numpy().dtype
38+
if isinstance(torchtype, torch.dtype):
39+
numpytype = torch.scalar_tensor(1, dtype=torchtype).numpy().dtype
40+
else:
41+
# in case it is already a numpy dtype
42+
numpytype = torchtype
2043
return numpytype
2144

2245

@@ -25,7 +48,7 @@ def torchtype_from_numpytype(numpytype):
2548
2649
Parameters
2750
----------
28-
numpytype : :obj:`torch.dtype`
51+
numpytype : :obj:`numpy.dtype`
2952
Numpy type
3053
3154
Returns
@@ -40,5 +63,10 @@ def torchtype_from_numpytype(numpytype):
4063
returned.
4164
4265
"""
43-
torchtype = torch.from_numpy(np.real(np.ones(1, dtype=numpytype))).dtype
66+
if isinstance(numpytype, torch.dtype):
67+
# in case it is already a torch dtype
68+
torchtype = numpytype
69+
else:
70+
torchtype = \
71+
torch.from_numpy(np.real(np.ones(1, dtype=numpytype_from_strtype(numpytype)))).dtype
4472
return torchtype

0 commit comments

Comments
 (0)