Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions src/spikeinterface/widgets/unit_spatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from __future__ import annotations

import numpy as np
from probeinterface import Probe
from probeinterface.plotting import get_auto_lims
from seaborn import color_palette
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should move these to the plot_matplotlib function. Also seaborn is not installed by default. Could you use a matplotlib palette?

Copy link
Contributor Author

@FrancescoNegri FrancescoNegri Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can definitely use a matplotlib palette, but I was thinking to add an optional seaborn import (if available) to enable also seaborn palettes. I used seaborn also in other parts of the function, I will try to fix it, though I am not sure I will be able to easily plot kernel desnity estimates with matplotlib only.

Is there a specific reason to move probeinterface imports to the plot_matplotlib function? I use its get_auto_lims function to compute xrange and yrange, that are independent on the visualization backend.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that souns good! The reason are the failing tests ;)

Core tests only install minimal dependencies. Upon collecting tests across modules, if something is not installed it'll throw an error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the core tests. It seems that probeinterface is installed, but matplotlib is not, thus probeinterface.plotting cannot be imported. Is that right?

from warnings import warn
from .base import BaseWidget, to_attr


class UnitSpatialDistributionsWidget(BaseWidget):
"""
Placeholder documentation to be changed.

Parameters
----------
sorting_analyzer : SortingAnalyzer
The SortingAnalyzer object
depth_axis : int, default: 1
The dimension of unit_locations that is depth
"""

def __init__(
self,
sorting_analyzer,
probe=None,
depth_axis=1,
bins=None,
cmap="viridis",
kde=False,
depth_hist=True,
groups=None,
kde_kws=None,
backend=None,
**backend_kwargs,
):
sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

self.check_extensions(sorting_analyzer, "unit_locations")
ulc = sorting_analyzer.get_extension("unit_locations")
unit_locations = ulc.get_data(outputs="numpy")
x, y = unit_locations[:, 0], unit_locations[:, 1]

if type(probe) is Probe:
if sorting_analyzer.recording.has_probe():
# TODO: throw warning saying that sorting_analyzer has a probe and it will be overwritten
pass
elif sorting_analyzer.recording.has_probe():
probe = sorting_analyzer.get_probe()
else:
# TODO: throw error or warning, no probe available
pass

xrange, yrange, _ = get_auto_lims(probe, margin=0)
if bins is None:
bins = (
np.round(np.diff(xrange).squeeze() / 75).astype(int),
np.round(np.diff(yrange).squeeze() / 75).astype(int),
)
# TODO: change behaviour, if bins is not defined, bin only along the depth axis

if type(cmap) is str:
cmap = color_palette(cmap, as_cmap=True)

plot_data = dict(
probe=probe,
x=x,
y=y,
depth_axis=depth_axis,
xrange=xrange,
yrange=yrange,
bins=bins,
kde=kde,
cmap=cmap,
depth_hist=depth_hist,
groups=groups,
kde_kws=kde_kws,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.patches as patches
import matplotlib.path as path
from seaborn import kdeplot, histplot
from .utils_matplotlib import make_mpl_figure

dp = to_attr(data_plot)

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

ax = self.ax

custom_shape = path.Path(dp.probe.probe_planar_contour)
patch = patches.PathPatch(custom_shape, facecolor="none", edgecolor="none")
ax.add_patch(patch)

if dp.kde is not True:
hist, xedges, yedges = np.histogram2d(dp.x, dp.y, bins=dp.bins, range=[dp.xrange, dp.yrange])
pcm = ax.pcolormesh(xedges, yedges, hist.T, cmap=dp.cmap)
else:
kde_kws = dict(levels=100, thresh=0, fill=True, bw_adjust=0.1)
if dp.kde_kws is not None:
kde_kws.update(dp.kde_kws)
data = dict(x=dp.x, y=dp.y)
bg = ax.add_patch(
patches.Rectangle(
[dp.xrange[0], dp.yrange[0]],
np.diff(dp.xrange).squeeze(),
np.diff(dp.yrange).squeeze(),
facecolor=dp.cmap.colors[0],
fill=True,
)
)
bg.set_clip_path(patch)
kdeplot(data, x="x", y="y", clip=[dp.xrange, dp.yrange], cmap=dp.cmap, ax=ax, **kde_kws)
pcm = ax.collections[0]
ax.set_xlabel(None)
ax.set_ylabel(None)

pcm.set_clip_path(patch)

xlim, ylim, _ = get_auto_lims(dp.probe, margin=10)
ax.set_xlim(*xlim)
ax.set_ylim(*ylim)
ax.spines["top"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xticks([])
ax.set_xlabel("")
ax.set_ylabel("Depth (um)")

if dp.depth_hist is True:
bbox = ax.get_window_extent()
hist_height = 1.5 * bbox.width

ax_hist = ax.inset_axes([1, 0, hist_height / bbox.width, 1])
data = dict(y=dp.y)
data["group"] = np.ones(dp.y.size) if dp.groups is None else dp.groups
palette = color_palette("bright", n_colors=1 if dp.groups is None else np.unique(dp.groups).size)
histplot(
data=data,
y="y",
hue="group",
bins=dp.bins[1],
binrange=dp.yrange,
palette=palette,
ax=ax_hist,
legend=False,
)
ax_hist.axis("off")
ax_hist.set_ylim(*ylim)
3 changes: 3 additions & 0 deletions src/spikeinterface/widgets/widget_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .unit_locations import UnitLocationsWidget
from .unit_presence import UnitPresenceWidget
from .unit_probe_map import UnitProbeMapWidget
from .unit_spatial import UnitSpatialDistributionsWidget
from .unit_summary import UnitSummaryWidget
from .unit_templates import UnitTemplatesWidget
from .unit_waveforms_density_map import UnitWaveformDensityMapWidget
Expand Down Expand Up @@ -67,6 +68,7 @@
UnitLocationsWidget,
UnitPresenceWidget,
UnitProbeMapWidget,
UnitSpatialDistributionsWidget,
UnitSummaryWidget,
UnitTemplatesWidget,
UnitWaveformDensityMapWidget,
Expand Down Expand Up @@ -142,6 +144,7 @@
plot_unit_locations = UnitLocationsWidget
plot_unit_presence = UnitPresenceWidget
plot_unit_probe_map = UnitProbeMapWidget
plot_unit_spatial_distribution = UnitSpatialDistributionsWidget
plot_unit_summary = UnitSummaryWidget
plot_unit_templates = UnitTemplatesWidget
plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget
Expand Down