diff --git a/nwbwidgets/allen.py b/nwbwidgets/allen.py index d48dd196..f406105b 100644 --- a/nwbwidgets/allen.py +++ b/nwbwidgets/allen.py @@ -19,10 +19,6 @@ def make_group_and_sort(self, group_by=None, control_order=False): ) -class AllenPSTHWidget(TimeIntervalsSelector): - InnerWidget = PSTHWidget - - class AllenRasterGridWidget(TimeIntervalsSelector): InnerWidget = RasterGridWidget @@ -85,7 +81,6 @@ def allen_show_electrodes(node: DynamicTable): def load_allen_widgets(): default_neurodata_vis_spec[Units]["Session Raster"] = AllenRasterWidget - default_neurodata_vis_spec[Units]["Grouped PSTH"] = AllenPSTHWidget default_neurodata_vis_spec[Units]["Raster Grid"] = AllenRasterGridWidget default_neurodata_vis_spec[Units]["Tuning Curves"] = AllenTuningCurveWidget # default_neurodata_vis_spec[DynamicTable] = allen_show_dynamic_table diff --git a/nwbwidgets/base.py b/nwbwidgets/base.py index 8a22acfa..a875c704 100644 --- a/nwbwidgets/base.py +++ b/nwbwidgets/base.py @@ -12,6 +12,7 @@ from nwbwidgets import view from pynwb import ProcessingModule from pynwb.core import NWBDataInterface, MultiContainerInterface +from pynwb.epoch import TimeIntervals from ipywidgets.widgets.interaction import show_inline_matplotlib_plots @@ -371,46 +372,61 @@ def row_to_hover_text(row): return "
".join(text_rows) -class TimeIntervalsSelector(widgets.VBox): - InnerWidget = None - - def __init__(self, input_data, **kwargs): +class TimeIntervalsSelectorMixin: + """ + Sister class must have intervals_selector_callback + """ + def set_interval_selector(self, intervals, nwbfile=None): """ - Creates a TimeInterval controller that controls InnerWidget. + If a string is given, look up that table in nwb.intervals + If a pynwb.epoch.TimeIntervals object is given, use that + If a Dropdown is given, use that as a selector + If there is no input for intervals, look at nwb.intervals + If nwb.intervals has 0 entries, render a placeholder + If nwb.intervals has 1 entry, use it + If nwb.intervals has more than one entry, create a dropdown of all available intervals Parameters ---------- - input_data: pynwb object - Pynwb object (e.g. pynwb.misc.Units) belonging to a nwbfile - that will be filtered by the TimeIntervalSelector controller. + intervals: str or pynwb.epoch.TimeIntervals or ipywidgets.DropDown or None + nwbfile: pynwb.NWBFile + + Returns + ------- + """ - super().__init__() - self.input_data = input_data - self.kwargs = kwargs - self.intervals_tables = input_data.get_ancestor("NWBFile").intervals - self.stimulus_type_dd = widgets.Dropdown( - options=list(self.intervals_tables.keys()), - description="stimulus type" - ) - self.stimulus_type_dd.observe(self.stimulus_type_dd_callback) - - trials = list(self.intervals_tables.values())[0] - inner_widget = self.InnerWidget( - units=self.input_data, - trials=trials, - **kwargs - ) - self.children = [self.stimulus_type_dd, inner_widget] - - def stimulus_type_dd_callback(self, change): - self.children = [self.stimulus_type_dd, widgets.HTML("Rendering...")] - trials = self.intervals_tables[self.stimulus_type_dd.value] - inner_widget = self.InnerWidget( - input_data=self.input_data, - trials=trials, - **self.kwargs - ) - self.children = [self.stimulus_type_dd, inner_widget] + self.intervals_dropdown = None + + if isinstance(intervals, str): + if intervals == "trials": + self.intervals = nwbfile.trials + elif intervals not in nwbfile.intervals: + raise ValueError("'{intervals}' not in NWBFile.intervals") + self.intervals = nwbfile.intervals[intervals] + elif isinstance(intervals, widgets.Dropdown): + self.intervals = nwbfile.intervals[self.intervals_dropdown.value] + self.intervals_dropdown.observe(self.intervals_selector_callback) + elif isinstance(intervals, TimeIntervals): + self.intervals = intervals + elif intervals is None: + all_intervals_tables = nwbfile.intervals + trials = nwbfile.trials + if trials is not None: + all_intervals_tables.add(trials) + if len(all_intervals_tables) == 0: + self.children = [HTML("could not find intervals")] + return + elif len(all_intervals_tables) == 1: + self.intervals = list(all_intervals_tables.values())[0] + else: + self.intervals_dropdown = widgets.Dropdown( + options=list(all_intervals_tables), + description="intervals", + ) + self.intervals_dropdown.observe(self.intervals_selector_callback) + self.intervals = list(all_intervals_tables.values())[0] + else: + raise ValueError("intervals is not an allowable type") def show_multi_container_interface( diff --git a/nwbwidgets/misc.py b/nwbwidgets/misc.py index 91730801..1e384fda 100644 --- a/nwbwidgets/misc.py +++ b/nwbwidgets/misc.py @@ -6,12 +6,13 @@ import plotly.graph_objects as go import pynwb import scipy -from ipywidgets import widgets, fixed, FloatProgress, Layout +from ipywidgets import widgets, fixed, FloatProgress, Layout, HTML from matplotlib.collections import PatchCollection from matplotlib.patches import Rectangle from pynwb.misc import AnnotationSeries, Units, DecompositionSeries from .analysis.spikes import compute_smoothed_firing_rate +from .base import TimeIntervalsSelectorMixin from .controllers import ( make_trial_event_controller, GroupAndSortController, @@ -267,27 +268,39 @@ def control_plot(x0, x1, ch0, ch1): return vbox -class PSTHWidget(widgets.VBox): +class PSTHWidget(widgets.VBox, TimeIntervalsSelectorMixin): def __init__( - self, - input_data: Units, - trials: pynwb.epoch.TimeIntervals = None, - unit_index=0, - unit_controller=None, - ntt=1000, + self, + units: Units, + intervals: str = None, + unit_index=0, + unit_controller=None, + ntt=1000, ): + """ + + Parameters + ---------- + input_data: pynwb.Units + intervals: str, optional + If a string is given, look up that table in nwb.intervals + If a TimeIntervals object is given, use that + If a Dropdown is given, use that as a selector + If there is no input for intervals, look at nwb.intervals + If nwb.intervals has 0 entries, render a placeholder + If nwb.intervals has 1 entry, use it + If nwb.intervals has more than one entry, create a dropdown of all available intervals + unit_index: int + unit_controller + ntt: int + """ - self.units = input_data + self.units = units + self.ntt = ntt super().__init__() - if trials is None: - self.trials = self.get_trials() - if self.trials is None: - self.children = [widgets.HTML("No trials present")] - return - else: - self.trials = trials + self.set_interval_selector(intervals, units.get_ancestor("NWBFile")) if unit_controller is None: self.unit_ids = self.units.id.data[:] @@ -301,8 +314,16 @@ def __init__( else: self.unit_controller = unit_controller + self.refresh_intervals() + + def make_group_and_sort(self, window=None, control_order=False): + return GroupAndSortController( + self.intervals, window=window, control_order=control_order + ) + + def refresh_intervals(self): self.trial_event_controller = make_trial_event_controller( - self.trials, layout=Layout(width="200px"), multiple=True + self.intervals, layout=Layout(width="200px"), multiple=True ) self.start_ft = widgets.FloatText( -0.5, step=0.1, description="start (s)", layout=Layout(width="200px"), @@ -332,7 +353,7 @@ def __init__( self.gas = self.make_group_and_sort(window=False, control_order=False) self.controls = dict( - ntt=fixed(ntt), + ntt=fixed(self.ntt), index=self.unit_controller, end=self.end_ft, start=self.start_ft, @@ -346,7 +367,7 @@ def __init__( out_fig = interactive_output(self.update, self.controls) - self.children = [ + children = [ widgets.HBox( [ widgets.VBox( @@ -373,13 +394,16 @@ def __init__( out_fig, ] - def get_trials(self): - return self.units.get_ancestor("NWBFile").trials + if self.intervals_dropdown is not None: + children.insert(0, self.intervals_dropdown) - def make_group_and_sort(self, window=None, control_order=False): - return GroupAndSortController( - self.trials, window=window, control_order=control_order - ) + self.children = children + + + def intervals_selector_callback(self, change): + self.children = [self.intervals_dropdown, widgets.HTML("Rendering...")] + self.intervals = self.units.get_ancestor("NWBFile").intervals[self.intervals_dropdown.value] + self.refresh_intervals() def update( self, @@ -441,7 +465,7 @@ def update( data = align_by_time_intervals( self.units, index, - self.trials, + self.intervals, start_label, start_label, start, @@ -485,7 +509,7 @@ def update( expanded_data = align_by_time_intervals( units=self.units, index=index, - intervals=self.trials, + intervals=self.intervals, start_label=start_label, stop_label=start_label, start=start - sigma_in_secs * 4, @@ -514,6 +538,7 @@ def update( ) ax1.set_xlim([start, end]) + ax1.set_xticks([start, end]) if i_s == 0: ax1.set_ylabel("firing rate (Hz)", fontsize=12) ax1.set_xlabel("time (s)", fontsize=12) @@ -1167,7 +1192,7 @@ class RasterGridWidget(widgets.VBox): def __init__( self, units: Units, - trials: pynwb.epoch.TimeIntervals = None, + intervals: pynwb.epoch.TimeIntervals = None, unit_index=0, units_trials_controller=None, ): @@ -1177,7 +1202,7 @@ def __init__( if not units_trials_controller: units_trials_controller = UnitsAndTrialsControllerWidget( units=units, - trials=trials, + trials=intervals, unit_index=unit_index ) self.children = [units_trials_controller] @@ -1195,7 +1220,7 @@ class TuningCurveWidget(widgets.VBox): def __init__( self, units: Units, - trials: pynwb.epoch.TimeIntervals = None, + intervals: pynwb.epoch.TimeIntervals = None, unit_index=0, units_trials_controller=None, ): @@ -1206,7 +1231,7 @@ def __init__( if not units_trials_controller: units_trials_controller = UnitsAndTrialsControllerWidget( units=units, - trials=trials, + trials=intervals, unit_index=unit_index ) self.children = [units_trials_controller] @@ -1239,7 +1264,7 @@ def __init__( # Tuning curve widget self.tuning_curve = TuningCurveWidget( units=units, - trials=trials, + intervals=trials, unit_index=unit_index, units_trials_controller=self.units_trials_controller, ) @@ -1247,7 +1272,7 @@ def __init__( # Raster grid widget self.raster_grid = RasterGridWidget( units=units, - trials=trials, + intervals=trials, unit_index=unit_index, units_trials_controller=self.units_trials_controller, ) diff --git a/nwbwidgets/utils/units.py b/nwbwidgets/utils/units.py index 963ec748..f78fdfc3 100644 --- a/nwbwidgets/utils/units.py +++ b/nwbwidgets/utils/units.py @@ -116,7 +116,7 @@ def align_by_time_intervals( index, intervals, start_label="start_time", - stop_label="stop_time", + stop_label=None, start=0.0, end=0.0, rows_select=(), diff --git a/test/test_misc.py b/test/test_misc.py index 2571ff25..46c0bc2c 100644 --- a/test/test_misc.py +++ b/test/test_misc.py @@ -18,7 +18,7 @@ ) from pynwb import NWBFile from pynwb.misc import DecompositionSeries, AnnotationSeries - +from pynwb.epoch import TimeIntervals def test_show_psth(): data = np.random.random([6, 50]) @@ -93,11 +93,34 @@ def test_psth_widget(self): def test_multipsth_widget(self): psth_widget = PSTHWidget(self.nwbfile.units) - assert isinstance(psth_widget, widgets.Widget) start_labels = ('start_time', 'stop_time') fig = psth_widget.update(index=0, start_labels=start_labels) assert len(fig.axes) == 2 * len(start_labels) - + + def test_multiple_intervals(self): + time_intervals = TimeIntervals("custom_intervals_table") + time_intervals.add_row(start_time=1., stop_time=2.) + time_intervals.add_row(start_time=2.5, stop_time=3.5) + self.nwbfile.intervals.add(time_intervals) + widget = PSTHWidget(self.nwbfile.units) + assert widget.intervals_dropdown is not None + + def test_input_intervals_trials_name(self): + + widget = PSTHWidget(self.nwbfile.units, intervals="trials") + assert widget.intervals.name == "trials" + assert widget.intervals_dropdown is None + + def test_input_intervals_object(self): + + time_intervals = TimeIntervals("custom_intervals_table") + time_intervals.add_row(start_time=1., stop_time=2.) + time_intervals.add_row(start_time=2.5, stop_time=3.5) + widget = PSTHWidget(self.nwbfile.units, intervals=time_intervals) + assert widget.intervals.name == "custom_intervals_table" + assert widget.intervals_dropdown is None + + def test_raster_widget(self): assert isinstance(RasterWidget(self.nwbfile.units), widgets.Widget) diff --git a/test/test_utils_units.py b/test/test_utils_units.py index 7c962a15..d5950073 100644 --- a/test/test_utils_units.py +++ b/test/test_utils_units.py @@ -114,7 +114,7 @@ def setUp(self): super().setUp() self.widget = TuningCurveWidget( units=self.nwbfile.units, - trials=self.nwbfile.trials + intervals=self.nwbfile.trials ) # rows controller triggers drawing of graphic self.widget.children[0].children[1].value = 'stim'