Skip to content

Commit 1f6e8f5

Browse files
committed
compute theil-sen estimator manually (exact and faster)
1 parent 0630977 commit 1f6e8f5

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

src/spikeinterface/metrics/template/metrics.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -283,22 +283,27 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y"
283283
sort_indices = np.argsort(channel_locations[:, depth_dim])
284284
return template[:, sort_indices], channel_locations[sort_indices, :]
285285

286-
287-
def fit_line_robust(x, y):
286+
def fit_line_robust(x, y, eps=1e-12):
288287
"""
289288
Fit line using robust Theil-Sen estimator (median of pairwise slopes).
290289
"""
291-
import sklearn.linear_model as lm
292-
293-
# Center data to improve numerical stability
294-
X = (x - x.mean()).reshape(-1, 1)
295-
y = y - y.mean()
296-
297-
theil = lm.TheilSenRegressor()
298-
theil.fit(X, y)
299-
slope = theil.coef_[0]
300-
score = theil.score(X, y) # R^2 score
301-
return slope, score
290+
import itertools
291+
292+
# Calculate slope and bias using Theil-Sen estimator
293+
slopes = []
294+
for (xs, ys) in zip(itertools.combinations(x, 2), itertools.combinations(y, 2)):
295+
if np.abs(xs[0] - xs[1]) > eps:
296+
slopes.append((ys[1] - ys[0]) / (xs[1] - xs[0]))
297+
if len(slopes) == 0: # all x are identical
298+
return np.nan, -np.inf
299+
slope = np.median(slopes)
300+
bias = np.median(y - slope * x)
301+
302+
# Calculate R2 score
303+
y_pred = slope * x + bias
304+
r2_score = 1 - ((y - y_pred)**2).sum() / (((y - y.mean())**2).sum() + eps)
305+
306+
return slope, r2_score
302307

303308

304309
def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs):

0 commit comments

Comments
 (0)