11__all__ = ["DTCWT" ]
22
3- from typing import Any , Union
3+ from typing import Any , NewType , Union
44
55import numpy as np
66
1919else :
2020 pyramid_type = Any
2121
22+ PyramidType = NewType ("PyramidType" , pyramid_type )
23+
2224
2325class DTCWT (LinearOperator ):
2426 r"""Dual-Tree Complex Wavelet Transform
@@ -146,16 +148,12 @@ def _nd_to_2d(self, arr_nd: NDArray) -> NDArray:
146148 arr_2d = arr_nd .reshape (self .dims [self .axis ], - 1 ).squeeze ()
147149 return arr_2d
148150
149- def _coeff_to_array (
150- self , pyr : pyramid_type
151- ) -> NDArray : # cannot use dtcwt types as it may not be installed
151+ def _coeff_to_array (self , pyr : PyramidType ) -> NDArray :
152152 highpass_coeffs = np .vstack ([h for h in pyr .highpasses ])
153153 coeffs = np .concatenate ((highpass_coeffs , pyr .lowpass ), axis = 0 )
154154 return coeffs
155155
156- def _array_to_coeff (
157- self , X : NDArray
158- ) -> pyramid_type : # cannot use dtcwt types as it may not be installed
156+ def _array_to_coeff (self , X : NDArray ) -> PyramidType :
159157 lowpass = (X [- self .lowpass_size :].real ).reshape ((- 1 , self .otherdims ))
160158 _ptr = 0
161159 highpasses = ()
@@ -166,9 +164,7 @@ def _array_to_coeff(
166164 highpasses += (_h ,)
167165 return dtcwt .Pyramid (lowpass , highpasses )
168166
169- def get_pyramid (
170- self , x : NDArray
171- ) -> pyramid_type : # cannot use dtcwt types as it may not be installed
167+ def get_pyramid (self , x : NDArray ) -> PyramidType :
172168 """Return Pyramid object from flat real-valued array"""
173169 return self ._array_to_coeff (x [0 ] + 1j * x [1 ])
174170
0 commit comments