11__all__ = ["DTCWT" ]
22
3- from typing import Union
3+ from typing import Any , NewType , Union
44
55import numpy as np
66
1515if dtcwt_message is None :
1616 import dtcwt
1717
18+ pyramid_type = dtcwt .numpy .common .Pyramid
19+ else :
20+ pyramid_type = Any
21+
22+ PyramidType = NewType ("PyramidType" , pyramid_type )
23+
1824
1925class DTCWT (LinearOperator ):
2026 r"""Dual-Tree Complex Wavelet Transform
@@ -122,7 +128,11 @@ def __init__(
122128 name = name ,
123129 )
124130
125- def _interpret_coeffs (self , dims , axis ):
131+ def _interpret_coeffs (
132+ self ,
133+ dims : Union [int , InputDimsLike ],
134+ axis : int ,
135+ ) -> None :
126136 x = np .ones (dims [axis ])
127137 pyr = self ._transform .forward (
128138 x , nlevels = self .level , include_scale = self .include_scale
@@ -134,16 +144,16 @@ def _interpret_coeffs(self, dims, axis):
134144 self .highpass_sizes .append (_h .size )
135145 self .coeff_array_size += _h .size
136146
137- def _nd_to_2d (self , arr_nd ) :
147+ def _nd_to_2d (self , arr_nd : NDArray ) -> NDArray :
138148 arr_2d = arr_nd .reshape (self .dims [self .axis ], - 1 ).squeeze ()
139149 return arr_2d
140150
141- def _coeff_to_array (self , pyr ): # cannot use dtcwt types as it may not be installed
151+ def _coeff_to_array (self , pyr : PyramidType ) -> NDArray :
142152 highpass_coeffs = np .vstack ([h for h in pyr .highpasses ])
143153 coeffs = np .concatenate ((highpass_coeffs , pyr .lowpass ), axis = 0 )
144154 return coeffs
145155
146- def _array_to_coeff (self , X ): # cannot use dtcwt types as it may not be installed
156+ def _array_to_coeff (self , X : NDArray ) -> PyramidType :
147157 lowpass = (X [- self .lowpass_size :].real ).reshape ((- 1 , self .otherdims ))
148158 _ptr = 0
149159 highpasses = ()
@@ -154,7 +164,7 @@ def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be install
154164 highpasses += (_h ,)
155165 return dtcwt .Pyramid (lowpass , highpasses )
156166
157- def get_pyramid (self , x ): # cannot use dtcwt types as it may not be installed
167+ def get_pyramid (self , x : NDArray ) -> PyramidType :
158168 """Return Pyramid object from flat real-valued array"""
159169 return self ._array_to_coeff (x [0 ] + 1j * x [1 ])
160170
0 commit comments