diff --git a/nwbwidgets/image.py b/nwbwidgets/image.py index bb6e599a..79de6ca5 100644 --- a/nwbwidgets/image.py +++ b/nwbwidgets/image.py @@ -91,56 +91,46 @@ def on_change(change): self.controls["time_window"].observe(on_change) -def show_image_series(image_series: ImageSeries, neurodata_vis_spec: dict): - if len(image_series.data.shape) == 3: - return show_grayscale_image_series(image_series, neurodata_vis_spec) +class ImageSeriesWidget(widgets.VBox): + + def __int__(self, image_series: ImageSeries, neurodata_vis_spec: dict = None): + self.image_series = image_series + + self.index_slider = widgets.IntSlider( + value=0, + min=0, + max=image_series.data.shape[0] - 1, + orientation="horizontal", + continuous_update=False, + description="index", + ) + + if len(image_series.data.shape) == 3: + self.show_image = self.show_grayscale_image + self.controls = {"index": self.index_slider} + out_fig = widgets.interactive_output(self.show_image, self.controls) + super().__init__(children=(out_fig, self.index_slider)) + else: + self.show_image = self.show_rgb_image + self.mode_dropdown = widgets.Dropdown( + options=("rgb", "bgr"), layout=Layout(width="200px"), description="mode" + ) + self.controls = {"index": self.index_slider, "mode": self.mode_dropdown} + out_fig = widgets.interactive_output(self.show_image, self.controls) + super.__init__(children=(out_fig, self.index_slider, self.mode_dropdown)) - def show_image(index=0, mode="rgb"): + def show_grayscale_image(self, index=0): fig, ax = plt.subplots(subplot_kw={"xticks": [], "yticks": []}) - image = image_series.data[index] - if mode == "bgr": - image = image[:, :, ::-1] - ax.imshow(image.transpose([1, 0, 2]), cmap="gray", aspect="auto") - fig.show() + ax.imshow(self.image_series.data[index].T, cmap="gray", aspect="auto") return fig2widget(fig) - slider = widgets.IntSlider( - value=0, - min=0, - max=image_series.data.shape[0] - 1, - orientation="horizontal", - continuous_update=False, - description="index", - ) - mode = widgets.Dropdown( - options=("rgb", "bgr"), layout=Layout(width="200px"), description="mode" - ) - controls = {"index": slider, "mode": mode} - out_fig = widgets.interactive_output(show_image, controls) - vbox = widgets.VBox(children=[out_fig, slider, mode]) - - return vbox - - -def show_grayscale_image_series(image_series: ImageSeries, neurodata_vis_spec: dict): - def show_image(index=0): + def show_rgb_image(self, index=0, mode="rgb"): fig, ax = plt.subplots(subplot_kw={"xticks": [], "yticks": []}) - ax.imshow(image_series.data[index].T, cmap="gray", aspect="auto") - return fig - - slider = widgets.IntSlider( - value=0, - min=0, - max=image_series.data.shape[0] - 1, - orientation="horizontal", - continuous_update=False, - description="index", - ) - controls = {"index": slider} - out_fig = widgets.interactive_output(show_image, controls) - vbox = widgets.VBox(children=[out_fig, slider]) - - return vbox + image = self.image_series.data[index] + if mode == "bgr": + image = image[:, :, ::-1] + ax.imshow(image.transpose([1, 0, 2]), aspect="auto") + return fig2widget(fig) def show_index_series(index_series, neurodata_vis_spec: dict): @@ -148,7 +138,7 @@ def show_index_series(index_series, neurodata_vis_spec: dict): series_widget = show_timeseries(index_series) indexed_timeseries = index_series.indexed_timeseries - image_series_widget = show_image_series(indexed_timeseries, neurodata_vis_spec) + image_series_widget = ImageSeriesWidget(indexed_timeseries, neurodata_vis_spec) return widgets.VBox([series_widget, image_series_widget]) diff --git a/test/test_controllers.py b/test/test_controllers.py index 76514efc..33081ff7 100644 --- a/test/test_controllers.py +++ b/test/test_controllers.py @@ -3,7 +3,7 @@ from hdmf.common import DynamicTable, VectorData from pynwb.ecephys import ElectrodeGroup, Device -from nwbwidgets.controllers import RangeController, GroupAndSortController +from nwbwidgets.controllers import RangeController, GroupAndSortController, StartAndDurationController class FloatRangeControllerTestCase(unittest.TestCase): @@ -72,3 +72,11 @@ def test_control(self): gas.order_dd.value = "Data1" gas.order_dd.value = None + + +class TestStartAndDurationController(unittest.TestCase): + def setUp(self) -> None: + self.start_and_duration_controller = StartAndDurationController(10) + + def test_set_duration(self): + self.start_and_duration_controller.duration.value = 2 diff --git a/test/test_image.py b/test/test_image.py index 65315cc2..7c158df9 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1,3 +1,5 @@ +import unittest + import ipywidgets as widgets import matplotlib.pyplot as plt import numpy as np @@ -5,7 +7,7 @@ show_rbga_image, show_grayscale_image, show_index_series, - show_image_series, + ImageSeriesWidget ) from nwbwidgets.view import default_neurodata_vis_spec from pynwb.base import TimeSeries @@ -47,10 +49,14 @@ def test_show_index_series(): ) -def test_show_image_series(): - data = np.random.rand(800).reshape((8, 10, 10)) - image_series = ImageSeries(name="Image Series", data=data, rate=1.0, unit='n.a.') +class TestImageSeriesWidget(unittest.TestCase): - assert isinstance( - show_image_series(image_series, default_neurodata_vis_spec), widgets.Widget - ) + def test_grascale(self): + data = np.random.rand(800).reshape((8, 10, 10)) + image_series = ImageSeries(name="Image Series", data=data, rate=1.0, unit='n.a.') + widget = ImageSeriesWidget(image_series) + + def test_rgb(self): + data = np.random.rand(800 * 3).reshape((8, 10, 10,3)) + image_series = ImageSeries(name="Image Series", data=data, rate=1.0, unit='n.a.') + widget = ImageSeriesWidget(image_series)