11__all__ = ["DTCWT" ]
22
3- from typing import Union
3+ from typing import Any , Union
44
55import numpy as np
66
1515if dtcwt_message is None :
1616 import dtcwt
1717
18+ pyramid_type = dtcwt .numpy .common .Pyramid
19+ else :
20+ pyramid_type = Any
21+
1822
1923class DTCWT (LinearOperator ):
2024 r"""Dual-Tree Complex Wavelet Transform
@@ -122,7 +126,11 @@ def __init__(
122126 name = name ,
123127 )
124128
125- def _interpret_coeffs (self , dims , axis ):
129+ def _interpret_coeffs (
130+ self ,
131+ dims : Union [int , InputDimsLike ],
132+ axis : int ,
133+ ) -> None :
126134 x = np .ones (dims [axis ])
127135 pyr = self ._transform .forward (
128136 x , nlevels = self .level , include_scale = self .include_scale
@@ -134,16 +142,20 @@ def _interpret_coeffs(self, dims, axis):
134142 self .highpass_sizes .append (_h .size )
135143 self .coeff_array_size += _h .size
136144
137- def _nd_to_2d (self , arr_nd ) :
145+ def _nd_to_2d (self , arr_nd : NDArray ) -> NDArray :
138146 arr_2d = arr_nd .reshape (self .dims [self .axis ], - 1 ).squeeze ()
139147 return arr_2d
140148
141- def _coeff_to_array (self , pyr ): # cannot use dtcwt types as it may not be installed
149+ def _coeff_to_array (
150+ self , pyr : pyramid_type
151+ ) -> NDArray : # cannot use dtcwt types as it may not be installed
142152 highpass_coeffs = np .vstack ([h for h in pyr .highpasses ])
143153 coeffs = np .concatenate ((highpass_coeffs , pyr .lowpass ), axis = 0 )
144154 return coeffs
145155
146- def _array_to_coeff (self , X ): # cannot use dtcwt types as it may not be installed
156+ def _array_to_coeff (
157+ self , X : NDArray
158+ ) -> pyramid_type : # cannot use dtcwt types as it may not be installed
147159 lowpass = (X [- self .lowpass_size :].real ).reshape ((- 1 , self .otherdims ))
148160 _ptr = 0
149161 highpasses = ()
@@ -154,7 +166,9 @@ def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be install
154166 highpasses += (_h ,)
155167 return dtcwt .Pyramid (lowpass , highpasses )
156168
157- def get_pyramid (self , x ): # cannot use dtcwt types as it may not be installed
169+ def get_pyramid (
170+ self , x : NDArray
171+ ) -> pyramid_type : # cannot use dtcwt types as it may not be installed
158172 """Return Pyramid object from flat real-valued array"""
159173 return self ._array_to_coeff (x [0 ] + 1j * x [1 ])
160174
0 commit comments