Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/spikeinterface/widgets/unit_locations.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import warnings
import numpy as np
from probeinterface import ProbeGroup

from spikeinterface.core.template_tools import get_template_extremum_channel
from spikeinterface.core.sortinganalyzer import SortingAnalyzer

from .base import BaseWidget, to_attr
from .utils import get_unit_colors
from spikeinterface.core.sortinganalyzer import SortingAnalyzer
import numpy as np


class UnitLocationsWidget(BaseWidget):
Expand Down Expand Up @@ -60,13 +63,13 @@ def __init__(
all_unit_locations = ulc.get_data()

x_locations = all_unit_locations[:, 0]
x_min = np.min(x_locations)
x_max = np.max(x_locations)
x_min = np.nanmin(x_locations)
x_max = np.nanmax(x_locations)
x_lim = (x_min - margin, x_max + margin)

y_locations = all_unit_locations[:, 1]
y_min = np.min(y_locations)
y_max = np.max(y_locations)
y_min = np.nanmin(y_locations)
y_max = np.nanmax(y_locations)
y_lim = (y_min - margin, y_max + margin)

sorting = sorting_analyzer.sorting
Expand All @@ -81,6 +84,13 @@ def __init__(
if unit_ids is None:
unit_ids = sorting.unit_ids

if np.any(np.isnan(all_unit_locations[sorting.ids_to_indices(unit_ids)])):
warnings.warn("Some unit locations contain NaN values. Replacing with extremum channel location.")
extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index")
for unit_id in unit_ids:
if np.any(np.isnan(unit_locations[unit_id])):
unit_locations[unit_id] = channel_locations[extremum_channel_indices[unit_id]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make a copy of unit_locaions in this case to avoid inplace replacement in the analyzer


data_plot = dict(
all_unit_ids=sorting.unit_ids,
unit_locations=unit_locations,
Expand Down Expand Up @@ -110,9 +120,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
from matplotlib.lines import Line2D

dp = to_attr(data_plot)
# backend_kwargs = self.update_backend_kwargs(**backend_kwargs)

# self.make_mpl_figure(**backend_kwargs)
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

unit_locations = dp.unit_locations
Expand Down
7 changes: 7 additions & 0 deletions src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations
from collections import defaultdict

import warnings
import numpy as np

from spikeinterface.core.template_tools import get_template_extremum_channel
from .base import BaseWidget, to_attr
from .utils import get_unit_colors

Expand Down Expand Up @@ -134,8 +136,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
col_counter += 1

unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index")
unit_location = unit_locations[unit_id]
x, y = unit_location[0], unit_location[1]
if np.isnan(x) or np.isnan(y):
warnings.warn(f"Unit {unit_id} location contains NaN values. Replacing NaN extremum channel location.")
x, y = sorting_analyzer.get_channel_locations()[extremum_channel_indices[unit_id]]

ax_unit_locations.set_xlim(x - 80, x + 80)
ax_unit_locations.set_ylim(y - 250, y + 250)
ax_unit_locations.set_xticks([])
Expand Down
Loading