1- # -*- coding: utf-8 -*-
21"""
32pysteps.cascade.bandpass_filters
43================================
@@ -64,10 +63,14 @@ def filter_uniform(shape, n):
6463 n: int
6564 Not used. Needed for compatibility with the filter interface.
6665
66+ Returns
67+ -------
68+ out: dict
69+ A dictionary containing the filter.
6770 """
6871 del n # Unused
6972
70- result = {}
73+ out = {}
7174
7275 try :
7376 height , width = shape
@@ -76,17 +79,23 @@ def filter_uniform(shape, n):
7679
7780 r_max = int (max (width , height ) / 2 ) + 1
7881
79- result ["weights_1d" ] = np .ones ((1 , r_max ))
80- result ["weights_2d" ] = np .ones ((1 , height , int (width / 2 ) + 1 ))
81- result ["central_freqs" ] = None
82- result ["central_wavenumbers" ] = None
83- result ["shape" ] = shape
82+ out ["weights_1d" ] = np .ones ((1 , r_max ))
83+ out ["weights_2d" ] = np .ones ((1 , height , int (width / 2 ) + 1 ))
84+ out ["central_freqs" ] = None
85+ out ["central_wavenumbers" ] = None
86+ out ["shape" ] = shape
8487
85- return result
88+ return out
8689
8790
8891def filter_gaussian (
89- shape , n , l_0 = 3 , gauss_scale = 0.5 , gauss_scale_0 = 0.5 , d = 1.0 , normalize = True
92+ shape ,
93+ n ,
94+ gauss_scale = 0.5 ,
95+ d = 1.0 ,
96+ normalize = True ,
97+ return_weight_funcs = False ,
98+ include_mean = True ,
9099):
91100 """
92101 Implements a set of Gaussian bandpass filters in logarithmic frequency
@@ -99,20 +108,20 @@ def filter_gaussian(
99108 the domain is assumed to have square shape.
100109 n: int
101110 The number of frequency bands to use. Must be greater than 2.
102- l_0: int
103- Central frequency of the second band (the first band is always centered
104- at zero).
105111 gauss_scale: float
106- Optional scaling prameter . Proportional to the standard deviation of
112+ Optional scaling parameter . Proportional to the standard deviation of
107113 the Gaussian weight functions.
108- gauss_scale_0: float
109- Optional scaling parameter for the Gaussian function corresponding to
110- the first frequency band.
111114 d: scalar, optional
112115 Sample spacing (inverse of the sampling rate). Defaults to 1.
113116 normalize: bool
114117 If True, normalize the weights so that for any given wavenumber
115118 they sum to one.
119+ return_weight_funcs: bool
120+ If True, add callable weight functions to the output dictionary with
121+ the key 'weight_funcs'.
122+ include_mean: bool
123+ If True, include the first Fourier wavenumber (corresponding to the
124+ field mean) to the first filter.
116125
117126 Returns
118127 -------
@@ -133,6 +142,8 @@ def filter_gaussian(
133142 except TypeError :
134143 height , width = (shape , shape )
135144
145+ max_length = max (width , height )
146+
136147 rx = np .s_ [: int (width / 2 ) + 1 ]
137148
138149 if (height % 2 ) == 1 :
@@ -145,13 +156,13 @@ def filter_gaussian(
145156
146157 r_2d = np .roll (np .sqrt (x_grid * x_grid + y_grid * y_grid ), dy , axis = 0 )
147158
148- max_length = max (width , height )
149-
150159 r_max = int (max_length / 2 ) + 1
151160 r_1d = np .arange (r_max )
152161
153162 wfs , central_wavenumbers = _gaussweights_1d (
154- max_length , n , l_0 = l_0 , gauss_scale = gauss_scale , gauss_scale_0 = gauss_scale_0
163+ max_length ,
164+ n ,
165+ gauss_scale = gauss_scale ,
155166 )
156167
157168 weights_1d = np .empty ((n , r_max ))
@@ -168,36 +179,48 @@ def filter_gaussian(
168179 weights_1d [k , :] /= weights_1d_sum
169180 weights_2d [k , :, :] /= weights_2d_sum
170181
171- result = {"weights_1d" : weights_1d , "weights_2d" : weights_2d }
172- result ["shape" ] = shape
182+ for i in range (len (wfs )):
183+ if i == 0 and include_mean :
184+ weights_1d [i , 0 ] = 1.0
185+ weights_2d [i , 0 , 0 ] = 1.0
186+ else :
187+ weights_1d [i , 0 ] = 0.0
188+ weights_2d [i , 0 , 0 ] = 0.0
189+
190+ out = {"weights_1d" : weights_1d , "weights_2d" : weights_2d }
191+ out ["shape" ] = shape
173192
174193 central_wavenumbers = np .array (central_wavenumbers )
175- result ["central_wavenumbers" ] = central_wavenumbers
194+ out ["central_wavenumbers" ] = central_wavenumbers
176195
177196 # Compute frequencies
178197 central_freqs = 1.0 * central_wavenumbers / max_length
179198 central_freqs [0 ] = 1.0 / max_length
180199 central_freqs [- 1 ] = 0.5 # Nyquist freq
181200 central_freqs = 1.0 * d * central_freqs
182- result ["central_freqs" ] = central_freqs
201+ out ["central_freqs" ] = central_freqs
202+
203+ if return_weight_funcs :
204+ out ["weight_funcs" ] = wfs
183205
184- return result
206+ return out
185207
186208
187- def _gaussweights_1d (l , n , l_0 = 3 , gauss_scale = 0.5 , gauss_scale_0 = 0.5 ):
188- e = pow (0.5 * l / l_0 , 1.0 / (n - 2 ))
189- r = [(l_0 * pow (e , k - 1 ), l_0 * pow (e , k )) for k in range (1 , n - 1 )]
209+ def _gaussweights_1d (l , n , gauss_scale = 0.5 ):
210+ q = pow (0.5 * l , 1.0 / n )
211+ r = [(pow (q , k - 1 ), pow (q , k )) for k in range (1 , n + 1 )]
212+ r = [0.5 * (r_ [0 ] + r_ [1 ]) for r_ in r ]
190213
191214 def log_e (x ):
192215 if len (np .shape (x )) > 0 :
193216 res = np .empty (x .shape )
194217 res [x == 0 ] = 0.0
195- res [x > 0 ] = np .log (x [x > 0 ]) / np .log (e )
218+ res [x > 0 ] = np .log (x [x > 0 ]) / np .log (q )
196219 else :
197220 if x == 0.0 :
198221 res = 0.0
199222 else :
200- res = np .log (x ) / np .log (e )
223+ res = np .log (x ) / np .log (q )
201224
202225 return res
203226
@@ -211,25 +234,11 @@ def __call__(self, x):
211234 return np .exp (- (x ** 2.0 ) / (2.0 * self .s ** 2.0 ))
212235
213236 weight_funcs = []
214- central_wavenumbers = [0.0 ]
215-
216- weight_funcs .append (GaussFunc (0.0 , gauss_scale_0 ))
237+ central_wavenumbers = []
217238
218239 for i , ri in enumerate (r ):
219- rc = log_e (ri [ 0 ] )
240+ rc = log_e (ri )
220241 weight_funcs .append (GaussFunc (rc , gauss_scale ))
221- central_wavenumbers .append (ri [0 ])
222-
223- gf = GaussFunc (log_e (l / 2 ), gauss_scale )
224-
225- def g (x ):
226- res = np .ones (x .shape )
227- mask = x <= l / 2
228- res [mask ] = gf (x [mask ])
229-
230- return res
231-
232- weight_funcs .append (g )
233- central_wavenumbers .append (l / 2 )
242+ central_wavenumbers .append (ri )
234243
235244 return weight_funcs , central_wavenumbers
0 commit comments