Skip to content

Commit ad3d9cc

Browse files
authored
Merge pull request #575 from mrava87/dev
doc: added safe typing to dtcwt
2 parents 99d91f1 + b8e35b2 commit ad3d9cc

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

pylops/signalprocessing/dtcwt.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = ["DTCWT"]
22

3-
from typing import Union
3+
from typing import Any, NewType, Union
44

55
import numpy as np
66

@@ -15,6 +15,12 @@
1515
if dtcwt_message is None:
1616
import dtcwt
1717

18+
pyramid_type = dtcwt.numpy.common.Pyramid
19+
else:
20+
pyramid_type = Any
21+
22+
PyramidType = NewType("PyramidType", pyramid_type)
23+
1824

1925
class DTCWT(LinearOperator):
2026
r"""Dual-Tree Complex Wavelet Transform
@@ -122,7 +128,11 @@ def __init__(
122128
name=name,
123129
)
124130

125-
def _interpret_coeffs(self, dims, axis):
131+
def _interpret_coeffs(
132+
self,
133+
dims: Union[int, InputDimsLike],
134+
axis: int,
135+
) -> None:
126136
x = np.ones(dims[axis])
127137
pyr = self._transform.forward(
128138
x, nlevels=self.level, include_scale=self.include_scale
@@ -134,16 +144,16 @@ def _interpret_coeffs(self, dims, axis):
134144
self.highpass_sizes.append(_h.size)
135145
self.coeff_array_size += _h.size
136146

137-
def _nd_to_2d(self, arr_nd):
147+
def _nd_to_2d(self, arr_nd: NDArray) -> NDArray:
138148
arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze()
139149
return arr_2d
140150

141-
def _coeff_to_array(self, pyr): # cannot use dtcwt types as it may not be installed
151+
def _coeff_to_array(self, pyr: PyramidType) -> NDArray:
142152
highpass_coeffs = np.vstack([h for h in pyr.highpasses])
143153
coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0)
144154
return coeffs
145155

146-
def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be installed
156+
def _array_to_coeff(self, X: NDArray) -> PyramidType:
147157
lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims))
148158
_ptr = 0
149159
highpasses = ()
@@ -154,7 +164,7 @@ def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be install
154164
highpasses += (_h,)
155165
return dtcwt.Pyramid(lowpass, highpasses)
156166

157-
def get_pyramid(self, x): # cannot use dtcwt types as it may not be installed
167+
def get_pyramid(self, x: NDArray) -> PyramidType:
158168
"""Return Pyramid object from flat real-valued array"""
159169
return self._array_to_coeff(x[0] + 1j * x[1])
160170

0 commit comments

Comments
 (0)