22
33import numpy as np
44
5- from . basepreprocessor import BasePreprocessor , BasePreprocessorSegment
6- from .filter import fix_dtype
7- from spikeinterface .core import order_channels_by_depth , get_chunk_with_margin
5+ from spikeinterface . preprocessing . basepreprocessor import BasePreprocessor , BasePreprocessorSegment , BaseRecording
6+ from spikeinterface . preprocessing .filter import fix_dtype
7+ from spikeinterface .core import order_channels_by_depth , get_chunk_with_margin , get_noise_levels
88from spikeinterface .core .core_tools import define_function_handling_dict_from_class
99
1010
@@ -48,8 +48,17 @@ class HighpassSpatialFilterRecording(BasePreprocessor):
4848 Order of spatial butterworth filter
4949 highpass_butter_wn : float, default: 0.01
5050 Critical frequency (with respect to Nyquist) of spatial butterworth filter
51+ epsilon : float, default: 0.003
52+ Value multiplied to RMS values to avoid division by zero during AGC.
53+ random_slice_kwargs : dict | None, default: None
54+ If not None, dictionary of arguments to be passed to `get_noise_levels` when computing
55+ noise levels.
5156 dtype : dtype, default: None
5257 The dtype of the output traces. If None, the dtype is the same as the input traces
58+ rms_values : np.ndarray | None, default: None
59+ If not None, array of RMS values for each channel to be used during AGC. If None, RMS values are computed
60+ from the recording. This is used to cache pre-computed RMS values, which are only computed once at
61+ initialization.
5362
5463 Returns
5564 -------
@@ -66,15 +75,18 @@ class HighpassSpatialFilterRecording(BasePreprocessor):
6675
6776 def __init__ (
6877 self ,
69- recording ,
78+ recording : BaseRecording ,
7079 n_channel_pad = 60 ,
7180 n_channel_taper = 0 ,
7281 direction = "y" ,
7382 apply_agc = True ,
7483 agc_window_length_s = 0.1 ,
7584 highpass_butter_order = 3 ,
7685 highpass_butter_wn = 0.01 ,
86+ epsilon = 0.003 ,
87+ random_slice_kwargs = None ,
7788 dtype = None ,
89+ rms_values = None ,
7890 ):
7991 BasePreprocessor .__init__ (self , recording )
8092
@@ -115,6 +127,14 @@ def __init__(
115127 if not apply_agc :
116128 agc_window_length_s = None
117129
130+ # Compute or retrieve RMS values
131+ if rms_values is None :
132+ if "noise_level_rms_raw" in recording .get_property_keys ():
133+ rms_values = recording .get_property ("noise_level_rms_raw" )
134+ else :
135+ random_slice_kwargs = {} if random_slice_kwargs is None else random_slice_kwargs
136+ rms_values = get_noise_levels (recording , method = "rms" , return_scaled = False , ** random_slice_kwargs )
137+
118138 # Pre-compute spatial filtering parameters
119139 butter_kwargs = dict (btype = "highpass" , N = highpass_butter_order , Wn = highpass_butter_wn )
120140 sos_filter = scipy .signal .butter (** butter_kwargs , output = "sos" )
@@ -133,6 +153,8 @@ def __init__(
133153 order_f ,
134154 order_r ,
135155 dtype = dtype ,
156+ epsilon = epsilon ,
157+ rms_values = rms_values ,
136158 )
137159 self .add_recording_segment (rec_segment )
138160
@@ -145,6 +167,7 @@ def __init__(
145167 agc_window_length_s = agc_window_length_s ,
146168 highpass_butter_order = highpass_butter_order ,
147169 highpass_butter_wn = highpass_butter_wn ,
170+ rms_values = rms_values ,
148171 )
149172
150173
@@ -161,6 +184,8 @@ def __init__(
161184 order_f ,
162185 order_r ,
163186 dtype ,
187+ epsilon ,
188+ rms_values ,
164189 ):
165190 BasePreprocessorSegment .__init__ (self , parent_recording_segment )
166191 self .parent_recording_segment = parent_recording_segment
@@ -185,6 +210,7 @@ def __init__(
185210 # get filter params
186211 self .sos_filter = sos_filter
187212 self .dtype = dtype
213+ self .epsilon_values_for_agc = epsilon * np .array (rms_values )
188214
189215 def get_traces (self , start_frame , end_frame , channel_indices ):
190216 if channel_indices is None :
@@ -207,8 +233,9 @@ def get_traces(self, start_frame, end_frame, channel_indices):
207233 traces = traces .copy ()
208234
209235 # apply AGC and keep the gains
236+ traces = traces .astype (np .float32 )
210237 if self .window is not None :
211- traces , agc_gains = agc (traces , window = self .window )
238+ traces , agc_gains = agc (traces , window = self .window , epsilons = self . epsilon_values_for_agc )
212239 else :
213240 agc_gains = None
214241 # pad the array with a mirrored version of itself and apply a cosine taper
@@ -255,36 +282,56 @@ def get_traces(self, start_frame, end_frame, channel_indices):
255282# -----------------------------------------------------------------------------------------------
256283
257284
258- def agc (traces , window , epsilon = 1e-8 ):
285+ def agc (traces , window , epsilons ):
259286 """
260287 Automatic gain control
261288 w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8)
262289 such as w_agc * gain = w
263- :param traces: seismic array (sample last dimension)
264- :param window_length: window length (secs) (original default 0.5)
265- :param si: sampling interval (secs) (original default 0.002)
266- :param epsilon: whitening (useful mainly for synthetic data)
267- :return: AGC data array, gain applied to data
290+
291+ Parameters
292+ ----------
293+ traces : np.ndarray
294+ Input traces
295+ window : np.ndarray
296+ Window to use for AGC (1D array)
297+ epsilons : np.ndarray[float]
298+ Epsilon values for each channel to avoid division by zero
299+
300+ Returns
301+ -------
302+ agc_traces : np.ndarray
303+ AGC applied traces
304+ gain : np.ndarray
305+ Gain applied to the traces
268306 """
269307 import scipy .signal
270308
271309 gain = scipy .signal .fftconvolve (np .abs (traces ), window [:, None ], mode = "same" , axes = 0 )
272310
273- gain += (np .sum (gain , axis = 0 ) * epsilon / traces .shape [0 ])[np .newaxis , :]
274-
275311 dead_channels = np .sum (gain , axis = 0 ) == 0
276312
277- traces [:, ~ dead_channels ] = traces [:, ~ dead_channels ] / gain [:, ~ dead_channels ]
313+ traces [:, ~ dead_channels ] = traces [:, ~ dead_channels ] / np . maximum ( epsilons , gain [:, ~ dead_channels ])
278314
279315 return traces , gain
280316
281317
282318def fcn_extrap (x , f , bounds ):
283319 """
284320 Extrapolates a flat value before and after bounds
285- x: array to be filtered
286- f: function to be applied between bounds (cf. fcn_cosine below)
287- bounds: 2 elements list or np.array
321+
322+ Parameters
323+ ----------
324+ x : np.ndarray
325+ Input array
326+ f : function
327+ Function to be applied between bounds
328+ bounds : list or np.ndarray
329+ 2 elements list or array defining the bounds
330+
331+ Returns
332+ -------
333+ y : np.ndarray
334+ Output array
288335 """
289336 y = f (x )
290337 y [x < bounds [0 ]] = f (bounds [0 ])
@@ -298,8 +345,16 @@ def fcn_cosine(bounds):
298345 values <= bounds[0]: values
299346 values < bounds[0] < bounds[1] : cosine taper
300347 values < bounds[1]: bounds[1]
301- :param bounds:
302- :return: lambda function
348+
349+ Parameters
350+ ----------
351+ bounds : list or np.ndarray
352+ 2 elements list or array defining the bounds
353+
354+ Returns
355+ -------
356+ func : function
357+ Lambda function implementing the soft thresholding with cosine taper
303358 """
304359
305360 def _cos (x ):
0 commit comments