Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 92 additions & 5 deletions phy/apps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ def on_close_view(view_, gui):


class FeatureMixin(object):
# Spike attributes that can be used for visualization in addition to the features.
_spike_attributes = ('amplitudes', 'depths')
n_spikes_features = 2500
n_spikes_features_background = 2500

Expand All @@ -286,22 +288,107 @@ class FeatureMixin(object):
)

_cached = (
'_get_features',
'get_spike_feature_amplitudes',
)

_memcached = (
# '_get_features_for_view',
# 'get_spike_attributes_for_views',
)

# This property provides a consistent public interface for views to get feature data,
# abstracting away the underlying implementation.
get_features = property(lambda self: self._get_features_for_view)

def _get_feature_spike_ids(self, cluster_id, load_all=False):
"""Return spike ids to be used in the feature view."""
if load_all:
return self.supervisor.get_spike_ids(cluster_id)
# Background spikes.
if cluster_id is None:
return self.selector(self.n_spikes_features_background, [])
# Spikes in a cluster.
return self.selector(self.n_spikes_features, [cluster_id])

def _get_features_for_view(self, cluster_ids, channel_ids=None, load_all=False):
"""Get features for a list of clusters.

This function is the main entry point for views to retrieve feature data.
It handles fetching data for both background spikes and specific clusters,
and determines the appropriate channels to use if not specified.
"""
if self.model.features is None:
return

# Special case for background spikes.
if cluster_ids is None:
spike_ids = self._get_feature_spike_ids(None, load_all=load_all)
if spike_ids is None or not len(spike_ids):
return
features = self.model.features[spike_ids, ...]
# We need to specify the channel ids, which are all channels in this case.
b_channel_ids = np.arange(self.model.channel_positions.shape[0])
b = Bunch(
data=features,
spike_ids=spike_ids,
channel_ids=b_channel_ids,
cluster_id=None,
)
# This is a list of bunches.
return [b]

bunchs = []
for cluster_id in cluster_ids:
spike_ids = self._get_feature_spike_ids(cluster_id, load_all=load_all)
if spike_ids is None or not len(spike_ids):
continue

# If channel_ids are not provided, get the best channels for the cluster.
if channel_ids is None:
c_ids = self.get_best_channels(cluster_id)
else:
c_ids = channel_ids

# Get the features for the specified channels.
features_bunch = self._get_spike_features(spike_ids, c_ids)
if not features_bunch:
continue

features_bunch.cluster_id = cluster_id
bunchs.append(features_bunch)
return bunchs

def get_spike_attributes_for_views(self):
"""Return a dictionary of functions `cluster_id => values`.

This method provides a flexible "data menu" for views. Instead of returning data
directly, it returns a dictionary of callable functions. Each function can be
invoked by a view to get a specific data attribute (e.g., depths, amplitudes)
for a cluster on demand. This design enables the creation of complex views
(like a 3D view) that require multiple independent data sources.
"""
d = {}
for name in self._spike_attributes:
# The function takes a cluster_id and returns an array.
d[name] = lambda cluster_id, name=name: getattr(
self.model, 'get_spike_%s' % name)(self._get_feature_spike_ids(cluster_id))
# Use helper that works across models (TemplateModel may not implement get_spike_times)
d['time'] = lambda cluster_id, load_all=False: self._get_feature_view_spike_times(
cluster_id, load_all=load_all)
return d

def get_spike_feature_amplitudes(
self, spike_ids, channel_id=None, channel_ids=None, pc=None, **kwargs):
"""Return the features for the specified channel and PC."""
self, spike_ids, channel_id=None, **kwargs):
"""Return the maximum amplitude of the features on one channel."""
if self.model.features is None:
return
channel_id = channel_id if channel_id is not None else channel_ids[0]
features = self._get_spike_features(spike_ids, [channel_id]).get('data', None)
if features is None: # pragma: no cover
return
assert features.shape[0] == len(spike_ids)
logger.log(5, "Show channel %s and PC %s in amplitude view.", channel_id, pc)
return features[:, 0, pc or 0]
logger.log(5, "Show channel %s and PC 0 in amplitude view.", channel_id)
return features[:, 0, 0]

def create_amplitude_view(self):
view = super(FeatureMixin, self).create_amplitude_view()
Expand Down
51 changes: 48 additions & 3 deletions phy/apps/template/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from phylib import _add_log_file
from phylib.io.model import TemplateModel, load_model
from phylib.io.traces import MtscompEphysReader
from phylib.utils import Bunch, connect
from phylib.utils import Bunch, connect, unconnect

from phy.cluster.views import ScatterView
from phy.cluster.views import ScatterView, Feature3DView
from phy.gui import create_app, run_app
from ..base import WaveformMixin, FeatureMixin, TemplateMixin, TraceMixin, BaseController

Expand Down Expand Up @@ -70,6 +70,7 @@ class TemplateController(WaveformMixin, FeatureMixin, TemplateMixin, TraceMixin,
'CorrelogramView',
'ISIView',
'FeatureView',
'Feature3DView',
'AmplitudeView',
'FiringRateView',
'TraceView',
Expand Down Expand Up @@ -141,6 +142,7 @@ def _get_template_features(self, cluster_ids, load_all=False):
def _set_view_creator(self):
super(TemplateController, self)._set_view_creator()
self.view_creator['TemplateFeatureView'] = self.create_template_feature_view
self.view_creator['Feature3DView'] = self.create_feature_3d_view

# Public methods
# -------------------------------------------------------------------------
Expand Down Expand Up @@ -194,6 +196,49 @@ def create_template_feature_view(self):
return
return TemplateFeatureView(coords=self._get_template_features)

def create_feature_3d_view(self):
"""Create and configure the 3D feature view.

This view requires multiple data sources to render the 3D scatter plot:
* `features`: The main feature data, typically used for the X and Y axes.
* `attributes`: A dictionary of other data vectors (like depth), used for the
Z axis and color. This is provided by `get_spike_attributes_for_views`.
* `channel_positions`: The physical layout of the probe channels.
"""
logger.debug("Creating Feature3DView")
try:
# Gather the different data sources required by the view.
features = self.get_features
attributes = self.get_spike_attributes_for_views()
channel_positions = self.model.channel_positions
logger.debug(f"Features: {features}")
logger.debug(f"Attributes: {attributes}")
logger.debug(f"Channel positions: {channel_positions.shape if channel_positions is not None else 'None'}")
view = Feature3DView(
features=features,
attributes=attributes,
channel_positions=channel_positions,
cluster_ids=self.supervisor.selected
)
logger.debug("Feature3DView created successfully")

# Connect the view to the supervisor's select event.
# This ensures the view is updated when the cluster selection changes.
@connect(sender=self.supervisor)
def on_select(sender, cluster_ids, **kwargs):
if view.auto_update:
view.on_select(cluster_ids=cluster_ids)

# Disconnect the view when it's closed to prevent memory leaks.
@connect(sender=view)
def on_close_view(view_, gui):
unconnect(on_select)

return view
except Exception as e:
logger.error(f"Error creating Feature3DView: {e}", exc_info=True)
raise


#------------------------------------------------------------------------------
# Template commands
Expand Down Expand Up @@ -226,4 +271,4 @@ def template_describe(params_path):
"""Describe a template dataset."""
model = load_model(params_path)
model.describe()
model.close()
model.close()
17 changes: 16 additions & 1 deletion phy/cluster/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from phylib.utils import Bunch, emit, connect, unconnect
from phy.gui.actions import Actions
from phy.gui.qt import _block, set_busy, _wait
from phy.gui.qt import _block, set_busy, _wait, QMessageBox
from phy.gui.widgets import Table, HTMLWidget, _uniq, Barrier

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1044,6 +1044,21 @@ def split(self, spike_ids=None, spike_clusters_rel=0):
out = self.clustering.split(
spike_ids, spike_clusters_rel=spike_clusters_rel)
self._global_history.action(self.clustering)

# Show a pop-up with the split information.
if out:
added = out.get('added', [])
deleted = out.get('deleted', [])
message = f"Split successful.\n\n"
if added:
message += f"New clusters created: {', '.join(map(str, added))}\n"
if deleted:
message += f"Original clusters affected: {', '.join(map(str, deleted))}"

box = QMessageBox()
box.setText(message)
box.exec_()

return out

# Move actions
Expand Down
1 change: 1 addition & 0 deletions phy/cluster/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .amplitude import AmplitudeView # noqa
from .correlogram import CorrelogramView # noqa
from .feature import FeatureView # noqa
from .featureview3d import Feature3DView # noqa
from .histogram import HistogramView, ISIView, FiringRateView # noqa
from .probe import ProbeView # noqa
from .raster import RasterView # noqa
Expand Down
2 changes: 2 additions & 0 deletions phy/cluster/views/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@ def attach(self, gui):
def toggle_automatic_channel_selection(self, checked):
"""Toggle the automatic selection of channels when the cluster selection changes."""
self.fixed_channels = not checked
# The status bar needs to be updated manually to reflect the change.
self.update_status()

@property
def status(self):
Expand Down
Loading