Skip to content

Commit ae77c86

Browse files
committed
doc: added safe typing to dtcwt
1 parent 7b8727a commit ae77c86

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

pylops/signalprocessing/dtcwt.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = ["DTCWT"]
22

3-
from typing import Union
3+
from typing import Any, Union
44

55
import numpy as np
66

@@ -15,6 +15,10 @@
1515
if dtcwt_message is None:
1616
import dtcwt
1717

18+
pyramid_type = dtcwt.numpy.common.Pyramid
19+
else:
20+
pyramid_type = Any
21+
1822

1923
class DTCWT(LinearOperator):
2024
r"""Dual-Tree Complex Wavelet Transform
@@ -122,7 +126,11 @@ def __init__(
122126
name=name,
123127
)
124128

125-
def _interpret_coeffs(self, dims, axis):
129+
def _interpret_coeffs(
130+
self,
131+
dims: Union[int, InputDimsLike],
132+
axis: int,
133+
) -> None:
126134
x = np.ones(dims[axis])
127135
pyr = self._transform.forward(
128136
x, nlevels=self.level, include_scale=self.include_scale
@@ -134,16 +142,20 @@ def _interpret_coeffs(self, dims, axis):
134142
self.highpass_sizes.append(_h.size)
135143
self.coeff_array_size += _h.size
136144

137-
def _nd_to_2d(self, arr_nd):
145+
def _nd_to_2d(self, arr_nd: NDArray) -> NDArray:
138146
arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze()
139147
return arr_2d
140148

141-
def _coeff_to_array(self, pyr): # cannot use dtcwt types as it may not be installed
149+
def _coeff_to_array(
150+
self, pyr: pyramid_type
151+
) -> NDArray: # cannot use dtcwt types as it may not be installed
142152
highpass_coeffs = np.vstack([h for h in pyr.highpasses])
143153
coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0)
144154
return coeffs
145155

146-
def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be installed
156+
def _array_to_coeff(
157+
self, X: NDArray
158+
) -> pyramid_type: # cannot use dtcwt types as it may not be installed
147159
lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims))
148160
_ptr = 0
149161
highpasses = ()
@@ -154,7 +166,9 @@ def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be install
154166
highpasses += (_h,)
155167
return dtcwt.Pyramid(lowpass, highpasses)
156168

157-
def get_pyramid(self, x): # cannot use dtcwt types as it may not be installed
169+
def get_pyramid(
170+
self, x: NDArray
171+
) -> pyramid_type: # cannot use dtcwt types as it may not be installed
158172
"""Return Pyramid object from flat real-valued array"""
159173
return self._array_to_coeff(x[0] + 1j * x[1])
160174

0 commit comments

Comments
 (0)