88"""
99
1010from typing import Union , Optional
11- from collections .abc import Iterable
11+ from collections .abc import Iterable , MutableSequence
1212
1313import astropy .units as u
1414import numpy as np
1515from astropy .convolution import Gaussian2DKernel
1616from astropy .units import Quantity
17- from numpy .typing import NDArray
17+ from numpy .typing import ArrayLike , NDArray
1818from scipy import signal
1919from scipy .ndimage import shift
2020from sunpy .map .map_factory import Map
@@ -258,7 +258,7 @@ def ms_clean(
258258 dirty_map : Quantity ,
259259 dirty_beam : Quantity ,
260260 pixel_size : Quantity [u .arcsec / u .pix ],
261- scales : Union [ Iterable , NDArray , None ] = None ,
261+ scales : Optional [ MutableSequence [ int ] ] = None ,
262262 clean_beam_width : Quantity = 4.0 * u .arcsec ,
263263 gain : float = 0.1 ,
264264 thres : float = 0.01 ,
@@ -281,9 +281,8 @@ def ms_clean(
281281 scale_sizes : NDArray [np .int_ ] = 2 ** np .arange (number_of_scales )
282282
283283 if scales :
284- scales = np .array (scales )
285284 number_of_scales = len (scales )
286- scale_sizes = scales
285+ scale_sizes = scales [:]
287286
288287 scale_sizes = np .where (scale_sizes == 0 , 1 , scale_sizes )
289288
@@ -298,10 +297,10 @@ def ms_clean(
298297
299298 # Pre-compute scales, residual maps and dirty beams at each scale and dirty beam cross terms
300299 scales = np .zeros ((dirty_map .shape [0 ], dirty_map .shape [1 ], number_of_scales ))
301- scaled_residuals = np .zeros ((dirty_map .shape [0 ], dirty_map .shape [1 ], number_of_scales ))
302- scaled_dirty_beams = np .zeros ((dirty_beam .shape [0 ], dirty_beam .shape [1 ], number_of_scales ))
303- max_scaled_dirty_beams = np .zeros (number_of_scales )
304- cross_terms = {}
300+ scaled_residuals : NDArray = np .zeros ((dirty_map .shape [0 ], dirty_map .shape [1 ], number_of_scales ))
301+ scaled_dirty_beams : NDArray = np .zeros ((dirty_beam .shape [0 ], dirty_beam .shape [1 ], number_of_scales ))
302+ max_scaled_dirty_beams : NDArray = np .zeros (number_of_scales )
303+ cross_terms : dict [ tuple [ int , int ], ArrayLike ] = {}
305304
306305 for i , scale in enumerate (scale_sizes ):
307306 scales [:, :, i ] = _component (scale = scale , shape = dirty_map .shape )
0 commit comments