diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index f5fcd60aea..c55c802f9b 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -1,11 +1,14 @@ from __future__ import annotations +import warnings +import numpy as np from probeinterface import ProbeGroup +from spikeinterface.core.template_tools import get_template_extremum_channel +from spikeinterface.core.sortinganalyzer import SortingAnalyzer + from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core.sortinganalyzer import SortingAnalyzer -import numpy as np class UnitLocationsWidget(BaseWidget): @@ -60,13 +63,13 @@ def __init__( all_unit_locations = ulc.get_data() x_locations = all_unit_locations[:, 0] - x_min = np.min(x_locations) - x_max = np.max(x_locations) + x_min = np.nanmin(x_locations) + x_max = np.nanmax(x_locations) x_lim = (x_min - margin, x_max + margin) y_locations = all_unit_locations[:, 1] - y_min = np.min(y_locations) - y_max = np.max(y_locations) + y_min = np.nanmin(y_locations) + y_max = np.nanmax(y_locations) y_lim = (y_min - margin, y_max + margin) sorting = sorting_analyzer.sorting @@ -81,6 +84,13 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids + if np.any(np.isnan(all_unit_locations[sorting.ids_to_indices(unit_ids)])): + warnings.warn("Some unit locations contain NaN values. Replacing with extremum channel location.") + extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + for unit_id in unit_ids: + if np.any(np.isnan(unit_locations[unit_id])): + unit_locations[unit_id] = channel_locations[extremum_channel_indices[unit_id]] + data_plot = dict( all_unit_ids=sorting.unit_ids, unit_locations=unit_locations, @@ -110,9 +120,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.lines import Line2D dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) unit_locations = dp.unit_locations diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index b1c1682c8a..fb26a228ef 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -1,8 +1,10 @@ from __future__ import annotations from collections import defaultdict +import warnings import numpy as np +from spikeinterface.core.template_tools import get_template_extremum_channel from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -134,8 +136,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): col_counter += 1 unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") + extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") unit_location = unit_locations[unit_id] x, y = unit_location[0], unit_location[1] + if np.isnan(x) or np.isnan(y): + warnings.warn(f"Unit {unit_id} location contains NaN values. Replacing NaN extremum channel location.") + x, y = sorting_analyzer.get_channel_locations()[extremum_channel_indices[unit_id]] + ax_unit_locations.set_xlim(x - 80, x + 80) ax_unit_locations.set_ylim(y - 250, y + 250) ax_unit_locations.set_xticks([])