Skip to content

Commit b8e35b2

Browse files
authored
Annotation as NewType dtcwt.py
1 parent ae77c86 commit b8e35b2

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

pylops/signalprocessing/dtcwt.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = ["DTCWT"]
22

3-
from typing import Any, Union
3+
from typing import Any, NewType, Union
44

55
import numpy as np
66

@@ -19,6 +19,8 @@
1919
else:
2020
pyramid_type = Any
2121

22+
PyramidType = NewType("PyramidType", pyramid_type)
23+
2224

2325
class DTCWT(LinearOperator):
2426
r"""Dual-Tree Complex Wavelet Transform
@@ -146,16 +148,12 @@ def _nd_to_2d(self, arr_nd: NDArray) -> NDArray:
146148
arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze()
147149
return arr_2d
148150

149-
def _coeff_to_array(
150-
self, pyr: pyramid_type
151-
) -> NDArray: # cannot use dtcwt types as it may not be installed
151+
def _coeff_to_array(self, pyr: PyramidType) -> NDArray:
152152
highpass_coeffs = np.vstack([h for h in pyr.highpasses])
153153
coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0)
154154
return coeffs
155155

156-
def _array_to_coeff(
157-
self, X: NDArray
158-
) -> pyramid_type: # cannot use dtcwt types as it may not be installed
156+
def _array_to_coeff(self, X: NDArray) -> PyramidType:
159157
lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims))
160158
_ptr = 0
161159
highpasses = ()
@@ -166,9 +164,7 @@ def _array_to_coeff(
166164
highpasses += (_h,)
167165
return dtcwt.Pyramid(lowpass, highpasses)
168166

169-
def get_pyramid(
170-
self, x: NDArray
171-
) -> pyramid_type: # cannot use dtcwt types as it may not be installed
167+
def get_pyramid(self, x: NDArray) -> PyramidType:
172168
"""Return Pyramid object from flat real-valued array"""
173169
return self._array_to_coeff(x[0] + 1j * x[1])
174170

0 commit comments

Comments
 (0)