diff --git a/nwbwidgets/file.py b/nwbwidgets/file.py index 8b036a44..d7b36c6d 100644 --- a/nwbwidgets/file.py +++ b/nwbwidgets/file.py @@ -56,4 +56,7 @@ def show_nwbfile(nwbfile: NWBFile, neurodata_vis_spec: dict) -> widgets.Widget: func_ = partial(view.nwb2widget, neurodata_vis_spec=neurodata_vis_spec) accordion = lazy_show_over_data(neuro_data, func_, labels=labels, style=widgets.Accordion) - return widgets.VBox(info + [accordion]) + from .time_grid import TimeGrid + + time_grid_widget = TimeGrid(nwbfile) + return widgets.VBox(info + [time_grid_widget] + [accordion]) diff --git a/nwbwidgets/time_grid.py b/nwbwidgets/time_grid.py new file mode 100644 index 00000000..edad125e --- /dev/null +++ b/nwbwidgets/time_grid.py @@ -0,0 +1,273 @@ +import matplotlib.cm as mplcm +import matplotlib.colors as colors +import matplotlib.pyplot as plt +import plotly.graph_objects as go +from hdmf.common import DynamicTable, VectorData +from pynwb import NWBFile, TimeSeries + + +def get_partial_path_from_parents(obj): + # ndx-events do not have object.data + parents = [] + while obj is not None and obj.name != "root": + parents.append(obj.name) + obj = obj.parent # assuming the object has a 'parent' attribute + return parents + + +def extract_timing_information_from_nwbfile(nwbfile: NWBFile, verbose: bool = False) -> dict: + temporal_information_dict = {} + for object_id, object in nwbfile.objects.items(): + object_name = object.name + if hasattr(object, "data") and hasattr(object.data, "name"): + object_path = object.data.name + top_module = object_path[1:].split("/")[0] + if top_module == "processing": + top_module = object_path[1:].split("/")[1] + + elif hasattr(object, "data"): + object_path = get_partial_path_from_parents(object)[::-1] + top_module = object_path[0] if object_path else "" + else: + object_path = get_partial_path_from_parents(object)[::-1] + top_module = object_path[0] if object_path else "" + + if isinstance(object, DynamicTable) and hasattr(object, "start_time") and hasattr(object, "stop_time"): + if verbose: + print("------------------") + print(f"branch: dynamic table with intervals for object {object_name} with path {object_path}") + print(f"Top module {top_module}") + print("------------------") + + start_time = object.start_time.data[:] + stop_time = object.stop_time.data[:] + + intervals = [ + {"start_time": start_time, "stop_time": stop_time} + for start_time, stop_time in zip(start_time, stop_time) + ] + timing_dict = dict(start_time=intervals[0]["start_time"], stop_time=intervals[-1]["stop_time"]) + object_info = dict( + object_path=object_path, object_name=object_name, object_id=object_id, top_module=top_module + ) + temporal_information_dict[object_name] = dict( + timing_dict=timing_dict, info=object_info, intervals=intervals + ) + + elif isinstance(object, DynamicTable) and hasattr(object, "start_time"): + if verbose: + print("------------------") + print("NOT DONE YET") + print(f"Branch of dynamic table with only start_time for object {object_name} with path {object_path}") + print(f"Top module {top_module}") + print("------------------") + + elif isinstance(object, TimeSeries): + time_series = object + if verbose: + print("------------------") + print(f"branch: time series for object {object_name} with path {object_path}") + print(f"Top module {top_module}") + print("------------------") + timing_dict = {} + if time_series.timestamps is not None: + timing_dict["start_time"] = time_series.timestamps[0] + timing_dict["stop_time"] = time_series.timestamps[-1] + else: + timing_dict["start_time"] = time_series.starting_time + timing_dict["stop_time"] = time_series.starting_time + time_series.num_samples / time_series.rate + + # TODO This fails when the series has timestamps with gaps. One option to find out the gaps and write the series + # As intervals is that the naive way np.diff generates a too large memory allocation. + # Probably, a bisection algorithm can be used to find the gaps. + + object_info = dict( + object_path=object_path, object_name=object_name, object_id=object_id, top_module=top_module + ) + temporal_information_dict[object_name] = dict(timing_dict=timing_dict, info=object_info) + + return temporal_information_dict + + +def generate_epoch_guides_trace(temporal_data_dict): + if "epochs" not in temporal_data_dict: + return [] + + epoch_dict = temporal_data_dict["epochs"] + intervals = epoch_dict["intervals"] + x_epoch = [] + y_epoch = [] + for i, epoch in enumerate(intervals, start=1): + start_time = epoch["start_time"] + stop_time = epoch["stop_time"] + x_epoch.extend([start_time, start_time, None]) + x_epoch.extend([stop_time, stop_time, None]) + y_epoch.extend([0.9, len(temporal_data_dict) + 1, None]) + y_epoch.extend([0.9, len(temporal_data_dict) + 1, None]) + + trace = go.Scatter( + x=x_epoch, + y=y_epoch, + mode="lines", + line=dict(color="#808080", width=1.5, dash="dash"), + name="epochs_guides", + visible=False, + ) + + return trace + + +def generate_epoch_annotations(temporal_data_dict) -> list: + if "epochs" not in temporal_data_dict: + return [] + + epoch_dict = temporal_data_dict["epochs"] + intervals = epoch_dict["intervals"] + + annotation_dict_list = [] + for i, epoch in enumerate(intervals, start=1): + start_time = epoch["start_time"] + stop_time = epoch["stop_time"] + + # Add a text box for each epoch + annotation_x = start_time + (stop_time - start_time) / 2 + annotation_y = len(temporal_data_dict) + len(temporal_data_dict) * 0.05 + annotation_dict = dict( + x=annotation_x, + y=annotation_y, + text=f"Epoch {i}", + showarrow=False, + font=dict(size=8, color="#000000"), + align="center", + borderwidth=2, + borderpad=4, + opacity=0.6, + ) + annotation_dict_list.append(annotation_dict) + return annotation_dict_list + + +def generate_total_duration_traces(temporal_data_dict, styling_dict): + traces_dict = dict() + for i, object_name in enumerate(styling_dict, start=1): + times_info_dict = temporal_data_dict[object_name] + info_dict = times_info_dict["info"] + timing_dict = times_info_dict["timing_dict"] + start_time = timing_dict["start_time"] + stop_time = timing_dict["stop_time"] + top_module = info_dict["top_module"] + color = styling_dict[object_name]["color"] + + trace = go.Scatter( + x=[start_time, stop_time], + y=[i, i], + mode="lines", + line=dict(color=color, width=6), + name=f"{top_module} - {object_name}", + legendgroup=top_module, # Group by top_module + hovertemplate=f"{object_name}
start: {start_time}
end: {stop_time}", + ) + traces_dict[object_name] = trace + + return traces_dict + + +def generate_time_grid_widget(temporal_data_dict: dict) -> go.FigureWidget: + fig = go.FigureWidget() + + # Key modules for temporal organization ad trials and epochs + key_modules = ["epochs", "trials"] + sorted_keys = [key for key in temporal_data_dict.keys() if key not in key_modules] + sorted_keys.sort(key=lambda x: temporal_data_dict[x]["info"]["top_module"]) + unique_top_modules = list(set([temporal_data_dict[key]["info"]["top_module"] for key in sorted_keys])) + + # if "epochs" in temporal_data_dict: + # sorted_keys = ["epochs"] + sorted_keys + # if "trials" in temporal_data_dict: + # sorted_keys = sorted_keys + ["trials"] + + # Generate a color map for the top modules + num_top_modules = len(unique_top_modules) + cm = plt.get_cmap("Accent") # Use 'gist_rainbow' colormap + cNorm = colors.Normalize(vmin=0, vmax=num_top_modules - 1) + scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm) + color_map = {module: colors.rgb2hex(scalarMap.to_rgba(i)) for i, module in enumerate(unique_top_modules)} + + styling_dict_epochs = dict(epochs=dict(color="#d3d3d3")) + styling_dict_rest = { + key: dict(color=color_map[temporal_data_dict[key]["info"]["top_module"]]) for key in sorted_keys + } + styling_dict_trials = dict(trials=dict(color="#d3d3d3")) + + styling_dict = dict() + if "trials" in temporal_data_dict: + styling_dict.update(styling_dict_trials) + + styling_dict.update(styling_dict_rest) + + if "epochs" in temporal_data_dict: + styling_dict.update(styling_dict_epochs) + + traces_dict = generate_total_duration_traces(temporal_data_dict, styling_dict) + for trace in traces_dict.values(): + fig.add_trace(trace) + + epoch_guide_trace = generate_epoch_guides_trace(temporal_data_dict) + if epoch_guide_trace: + fig.add_trace(epoch_guide_trace) + + annotation_dict_list = generate_epoch_annotations(temporal_data_dict) + # Button for toggling epoch guides + epoch_guide_button = dict( + type="buttons", + direction="down", + showactive=True, + buttons=[ + dict( + label="Toggle Epoch Guides", + method="update", + args=[{"visible": [True] * len(traces_dict) + [False]}, {"annotations": []}], + args2=[{"visible": [True] * len(traces_dict) + [True]}, {"annotations": annotation_dict_list}], + ), + ], + x=0.10, # Position from left (0 to 1) + y=-0.15, # Position from bottom (0 to 1) + ) + if epoch_guide_trace: + updatemenus = [epoch_guide_button] + else: + updatemenus = None + ticktext = list(styling_dict.keys()) + tickvals = list(range(1, len(temporal_data_dict) + 1)) + fig.update_layout( + xaxis_title="Time (s)", + yaxis=dict( + tickvals=tickvals, + ticktext=ticktext, + range=[0.5, len(ticktext) * 1.25], + ticks="", + showline=False, + showticklabels=True, + showgrid=True, + gridwidth=0.1, + ), + updatemenus=updatemenus, + ) + + return fig + + +from ipywidgets import Layout, fixed, widgets + + +class TimeGrid(widgets.VBox): + def __init__(self, nwbfile: NWBFile): + super().__init__() + + self.nwbfile = nwbfile + # This is the calculation part + self.temporal_data_dict = extract_timing_information_from_nwbfile(nwbfile=self.nwbfile) + + # This is the figure widget + self.figure_widget = generate_time_grid_widget(temporal_data_dict=self.temporal_data_dict) + self.children = [self.figure_widget]