diff --git a/ephys/classes/experiment_objects.py b/ephys/classes/experiment_objects.py index b4d94cb..1d2ec11 100644 --- a/ephys/classes/experiment_objects.py +++ b/ephys/classes/experiment_objects.py @@ -216,7 +216,6 @@ def add_file_info( time_list.append(time_rec) time_list.sort() estimated_exp_date = time_list[0] - # NOTE: abf files have date of experiment in the header file_list.append( { "data_of_creation": time_created, diff --git a/ephys/classes/plot/__init__.py b/ephys/classes/plot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ephys/classes/plot/plot_params.py b/ephys/classes/plot/plot_params.py new file mode 100644 index 0000000..9482d2e --- /dev/null +++ b/ephys/classes/plot/plot_params.py @@ -0,0 +1,147 @@ +""" +This module defines the PlotParams class, which encapsulates configurable +parameters for plotting electrophysiological traces. It provides options for +customizing signal type, channels, colors, transparency, averaging, alignment, +sweep selection, plot appearance, time windows, axis limits, and theme. The +class supports validation, parameter updates, theme application, and conversion +to dictionary format for flexible and robust plotting workflows. +""" + +from typing import Any +import numpy as np + + +class PlotParams: + """Parameters for plotting traces. + + Args: + signal_type (str): Type of signal to plot ('current' or 'voltage'). + channels (np.ndarray): Channels to plot. + color (str): Color or colormap name for individual traces. + alpha (float): Transparency for individual traces. + average (bool): Whether to plot the average trace. + avg_color (str): Color for the average trace. + align_onset (bool): Whether to align traces on onset. + sweep_subset (Any): Subset of sweeps to plot. + bg_color (str): Background color for the plot. + axis_color (str): Color for axes. + window (list[tuple[float, float]]): Time windows for plotting. + window_color (str): Color for window regions. + xlim (tuple[float, float]): X-axis limits. + show (bool): Whether to show the plot. + return_fig (bool): Whether to return the figure object. + window_mode (str): Mode for handling windows + ('use_plot', 'use_trace', 'add_to_trace'). + theme (str): Theme for the plot ('dark' or 'light'). + """ + + def __init__( + self, + **kwargs: Any, + ) -> None: + """Initialize PlotParams with default values or provided kwargs.""" + self.signal_type = kwargs.get("signal_type", "") + self.channels = kwargs.get("channels", np.array([], dtype=np.int64)) + self.color = kwargs.get("color", "white") + self.alpha = kwargs.get("alpha", 1.0) + self.average = kwargs.get("average", False) + self.avg_color = kwargs.get("avg_color", "red") + self.align_onset = kwargs.get("align_onset", True) + self.sweep_subset = kwargs.get("sweep_subset", None) + self.bg_color = kwargs.get("bg_color", "black") + self.axis_color = kwargs.get("axis_color", "white") + self.window = kwargs.get("window", [(0, 0)]) + self.window_color = kwargs.get("window_color", "gray") + self.xlim = kwargs.get("xlim", (0, 0)) + self.show = kwargs.get("show", True) + self.return_fig = kwargs.get("return_fig", False) + self.window_mode = kwargs.get( + "window_mode", "add_to_trace" + ) # Default mode for handling windows + self.theme = kwargs.get("theme", "dark") + self.apply_theme(self.theme) + + def apply_theme(self, theme="dark") -> None: + """Apply the specified theme to the plot parameters. + + Args: + theme (str): Theme for the plot ('dark' or 'light'). + """ + self.theme = theme + if self.theme == "dark": + self.color = "#d0d0d0ff" + self.bg_color = "#292929" + self.axis_color = "white" + self.window_color = "#8A8A8A" + elif self.theme == "light": + self.color = "#1D1D1DFF" + self.bg_color = "#FFFFFF" + self.axis_color = "#000000" + self.window_color = "#ACACAC" + else: + raise ValueError(f"Invalid theme: {self.theme}. Must be 'dark' or 'light'.") + + def update_params(self, **kwargs: Any) -> None: + """Update the plot parameters with provided keyword arguments.""" + if "theme" in kwargs: + self.apply_theme(kwargs["theme"]) + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + raise AttributeError(f"PlotParams has no attribute '{key}'.") + + self.validate() + + def validate(self) -> None: + """Validate the plot parameters.""" + if not isinstance(self.signal_type, str): + raise TypeError("signal_type must be a string.") + if not isinstance(self.channels, np.ndarray): + raise TypeError("channels must be a numpy ndarray.") + if not isinstance(self.window, list) or not all( + isinstance(win, tuple) and len(win) == 2 for win in self.window + ): + raise TypeError("window must be a list of tuples (start, end).") + if ( + not isinstance(self.xlim, tuple) + or len(self.xlim) > 2 + or len(self.xlim) == 1 + ): + raise TypeError("xlim must be a tuple of two values (min, max).") + if not isinstance(self.show, bool): + raise TypeError("show must be a boolean value.") + if not isinstance(self.return_fig, bool): + raise TypeError("return_fig must be a boolean value.") + if self.window_mode not in ["use_plot", "use_trace", "add_to_trace"]: + raise ValueError( + f"Invalid window_mode: {self.window_mode}. " + "Must be 'use_plot', 'use_trace', or 'add_to_trace'." + ) + + def to_dict(self) -> dict: + """Convert the plot parameters to a dictionary.""" + return { + "signal_type": self.signal_type, + "channels": self.channels.tolist(), + "color": self.color, + "alpha": self.alpha, + "average": self.average, + "avg_color": self.avg_color, + "align_onset": self.align_onset, + "sweep_subset": self.sweep_subset, + "bg_color": self.bg_color, + "axis_color": self.axis_color, + "window": self.window, + "window_color": self.window_color, + "xlim": self.xlim, + "show": self.show, + "return_fig": self.return_fig, + "window_mode": self.window_mode, + "theme": self.theme, + } + + def __iter__(self): + """Iterate over the plot parameters.""" + for key, value in self.to_dict().items(): + yield key, value diff --git a/ephys/classes/plot/plot_trace.py b/ephys/classes/plot/plot_trace.py new file mode 100644 index 0000000..7149d67 --- /dev/null +++ b/ephys/classes/plot/plot_trace.py @@ -0,0 +1,464 @@ +""" +This module contains classes for plotting traces using different backends. +It includes support for both PyQtGraph and Matplotlib, allowing for flexible +visualization of trace data +""" + +from typing import Any + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.axes import Axes +import matplotlib.colors as mcolors +import pyqtgraph as pg + +from ephys.classes.trace import Trace +from ephys.classes.class_functions import _get_sweep_subset +from ephys import utils +from ephys.classes.plot.plot_params import PlotParams + + +class TracePlot: + """Base class for plotting traces.""" + + def __init__(self, trace: Trace, backend: str = "pyqt", **kwargs) -> None: + """ + Args: + trace (Trace): The trace object to plot. + backend (str): The plotting backend to use ('pyqt' or 'matplotlib'). + **kwargs: Additional parameters for plotting. + """ + + # check backend + if backend not in ["pyqt", "matplotlib"]: + raise ValueError( + f"Unsupported backend: {backend}. Choose 'pyqt' or 'matplotlib'." + ) + self.params = PlotParams(**kwargs) + self.backend = backend + self.trace = trace + + def handle_windows(self) -> list: + """Handle window interaction between plot parameters and trace""" + + # Initialize windows to display + windows_to_display = [] + + # Case 1: Add the plot windows to trace.window + if self.params.window_mode == "add_to_trace": + # Convert input to proper format + if isinstance(self.params.window, tuple): + plot_windows = [self.params.window] + elif isinstance(self.params.window, list): + plot_windows = self.params.window + else: + raise TypeError("Window must be a tuple or list of tuples.") + + # Initialize trace.window if it doesn't exist + if self.trace.window is None: + self.trace.window = [] + + # Add new windows from plot parameters + if plot_windows != [(0, 0)]: + # Add each new window to trace.window + for win in plot_windows: + if win not in self.trace.window: + self.trace.window.append(win) + + # Use the updated trace.window for display + windows_to_display = self.trace.window + + # Case 2: Use existing trace.window + elif self.params.window_mode == "use_trace": + # Use existing trace.window for display + if self.trace.window is None or len(self.trace.window) == 0: + windows_to_display = [] + else: + windows_to_display = self.trace.window + + # Case 3: Use windows from plot without modifying trace.window + else: # self.params.window_mode == "use_plot" + if isinstance(self.params.window, tuple): + windows_to_display = [self.params.window] + elif isinstance(self.params.window, list): + windows_to_display = self.params.window + else: + raise TypeError("Window must be a tuple or list of tuples.") + return windows_to_display + + def _prepare_time_array(self, trace_select): + """Prepare time array based on alignment settings""" + if self.params.align_onset: + return trace_select.set_time( + align_to_zero=True, + cumulative=False, + stimulus_interval=0.0, + overwrite_time=False, + ) + return trace_select.time + + +class TracePlotMatplotlib(TracePlot): + """Class for plotting traces using Matplotlib.""" + + def __init__(self, trace: Trace, backend: str = "matplotlib", **kwargs) -> None: + super().__init__(trace=trace, backend=backend, **kwargs) + + def plot( + self, + **kwargs, + ) -> None | tuple: + """ + Plots the traces for the specified channels. + + Args: + signal_type (str): The type of signal_type to use. Must be either 'current' or + 'voltage'. + channels (list, optional): The list of channels to plot. If None, all channels + will be plotted. + Defaults to None. + average (bool, optional): Whether to plot the average trace. + Defaults to False. + color (str, optional): The color of the individual traces. Can be a colormap. + Defaults to 'black'. + alpha (float, optional): The transparency of the individual traces. + Defaults to 0.5. + avg_color (str, optional): The color of the average trace. + Defaults to 'red'. + align_onset (bool, optional): Whether to align the traces on the onset. + Defaults to True. + sweep_subset (Any, optional): The subset of sweeps to plot. + Defaults to None. + window (tuple, optional): The time window to plot. + Defaults to (0, 0). + show (bool, optional): Whether to display the plot. + Defaults to True. + return_fig (bool, optional): Whether to return the figure. + Defaults to False. + + Returns: + None or Figure: If show is True, returns None. If return_fig is True, + returns the figure. + """ + if kwargs: + self.params.update_params(**kwargs) + if len(self.params.channels) == 0: + self.params.channels = self.trace.channel_information.channel_number + sweep_subset = _get_sweep_subset( + array=self.trace.time, sweep_subset=self.params.sweep_subset + ) + trace_select = self.trace.subset( + channels=self.params.channels, + signal_type=self.params.signal_type, + sweep_subset=self.params.sweep_subset, + ) + + fig, channel_axs = plt.subplots(len(trace_select.channel), 1, sharex=True) + # color background and axis + + # change color of all axes + self._set_axs_color(input_axs=channel_axs) + + fig.set_facecolor(self.params.bg_color) + if isinstance(channel_axs, Axes): + channel_axs.set_facecolor(color=self.params.bg_color) + elif isinstance(channel_axs, np.ndarray): + for axs in channel_axs: + axs.set_facecolor(self.params.bg_color) + else: + raise TypeError("channel_axs must be an Axes or np.ndarray of Axes.") + if len(trace_select.channel) == 0: + print("No traces found.") + return None + + time_array = self._prepare_time_array(trace_select) + + windows_to_display = self.handle_windows() + + tmp_axs: Axes | None = None + for channel_index, channel in enumerate(trace_select.channel): + if len(trace_select.channel) == 1: + if isinstance(channel_axs, Axes): + tmp_axs = channel_axs + else: + if isinstance(channel_axs, np.ndarray): + if isinstance(channel_axs[channel_index], Axes): + tmp_axs = channel_axs[channel_index] + if tmp_axs is None: + pass + else: + for i in range(channel.data.shape[0]): + tmp_axs.plot( + time_array[i, :], + channel.data[i, :], + color=utils.trace_color( + traces=channel.data, index=i, color=self.params.color + ), + alpha=self.params.alpha, + ) + + if windows_to_display != [(0, 0)]: + if ( + isinstance(windows_to_display, list) + and len(windows_to_display) != 0 + ): + for win in windows_to_display: + if isinstance(tmp_axs, Axes): + tmp_axs.axvspan( + xmin=win[0], + xmax=win[1], + color=self.params.window_color, + alpha=0.5, + ) + if self.params.average: + channel.channel_average(sweep_subset=sweep_subset) + tmp_axs.plot( + time_array[0, :], + channel.average.trace, + color=self.params.avg_color, + ) + tmp_axs.set_ylabel( + ylabel=( + f"Channel {trace_select.channel_information.channel_number[channel_index]} " + f"({trace_select.channel_information.unit[channel_index]})" + ) + ) + if isinstance(tmp_axs, Axes): + tmp_axs.set_xlabel( + xlabel=f"Time ({trace_select.time.units.dimensionality.string})" + ) + if len(self.params.xlim) == 2: + if self.params.xlim[0] < self.params.xlim[1]: + tmp_axs.set_xlim( + left=self.params.xlim[0], right=self.params.xlim[1] + ) + else: + tmp_axs.set_xlim(left=time_array.min(), right=time_array.max()) + plt.tight_layout() + if self.params.show: + plt.show() + if self.params.return_fig: + return fig, channel_axs + return None + + def _set_axs_color(self, input_axs: Axes | np.ndarray) -> None: + """Set the background and axis color for the given axes.""" + if isinstance(input_axs, Axes): + input_axs.set_facecolor(self.params.bg_color) + input_axs.spines["bottom"].set_color(self.params.axis_color) + input_axs.spines["left"].set_color(self.params.axis_color) + # remove top and right spines + input_axs.spines["top"].set_visible(False) + input_axs.spines["right"].set_visible(False) + input_axs.tick_params(axis="x", colors=self.params.axis_color) + input_axs.tick_params(axis="y", colors=self.params.axis_color) + # title color + input_axs.title.set_color(self.params.axis_color) + input_axs.xaxis.label.set_color(self.params.axis_color) + input_axs.yaxis.label.set_color(self.params.axis_color) + elif isinstance(input_axs, np.ndarray): + for axs in input_axs: + axs.set_facecolor(self.params.bg_color) + axs.spines["bottom"].set_color(self.params.axis_color) + axs.spines["left"].set_color(self.params.axis_color) + # remove top and right spines + axs.spines["top"].set_visible(False) + axs.spines["right"].set_visible(False) + axs.tick_params(axis="x", colors=self.params.axis_color) + axs.tick_params(axis="y", colors=self.params.axis_color) + # title color + axs.title.set_color(self.params.axis_color) + axs.xaxis.label.set_color(self.params.axis_color) + axs.yaxis.label.set_color(self.params.axis_color) + else: + raise TypeError("channel_axs must be an Axes or np.ndarray of Axes.") + + +class TracePlotPyQt(TracePlot): + """Class for plotting traces using PyQtGraph.""" + + def __init__(self, trace: Trace, backend: str = "pyqt", **kwargs) -> None: + super().__init__(trace=trace, backend=backend, **kwargs) + + def plot( + self, + **kwargs: Any, + ) -> None | pg.GraphicsLayoutWidget: + """ + Plots the traces for the specified channels. + + Args: + signal_type (str): The type of signal_type to use. Must be either 'current' or + 'voltage'. + channels (list, optional): The list of channels to plot. If None, all channels + will be plotted. + Defaults to None. + average (bool, optional): Whether to plot the average trace. + Defaults to False. + color (str, optional): The color of the individual traces. Can be a colormap. + Defaults to 'black'. + alpha (float, optional): The transparency of the individual traces. + Defaults to 0.5. + avg_color (str, optional): The color of the average trace. + Defaults to 'red'. + align_onset (bool, optional): Whether to align the traces on the onset. + Defaults to True. + sweep_subset (Any, optional): The subset of sweeps to plot. + Defaults to None. + window (tuple, optional): The time window to plot. + Defaults to (0, 0). + show (bool, optional): Whether to display the plot. + Defaults to True. + return_fig (bool, optional): Whetherupdate_params to return the figure. + Defaults to False. + + Returns: + None or Figure: If show is True, returns None. If return_fig is True, + returns the figure. + """ + if kwargs: + self.params.update_params(**kwargs) + + def sync_channels(source_region, channel_items, window_index=0): + # Get region bounds from the source region + min_val, max_val = source_region.getRegion() + + # Update all other regions + for r in channel_items: + if r is not source_region: + r.blockSignals(True) + r.setRegion((min_val, max_val)) + r.blockSignals(False) + # Update the trace window property only if we're not in "use_plot" mode + if self.params.window_mode != "use_plot": + if isinstance(self.trace.window, list) and window_index < len( + self.trace.window + ): + self.trace.window[window_index] = (min_val, max_val) + else: + # Handle case where window_index is out of bounds or trace.window is not a list + pass + + def make_region_callback(region_obj, channel_items, window_index=0): + return lambda: sync_channels( + source_region=region_obj, + channel_items=channel_items, + window_index=window_index, + ) + + if len(self.params.channels) == 0: + self.params.channels = self.trace.channel_information.channel_number + trace_select = self.trace.subset( + channels=self.params.channels, + signal_type=self.params.signal_type, + sweep_subset=self.params.sweep_subset, + ) + + time_array = self._prepare_time_array(trace_select) + + if len(self.params.xlim) > 2: + raise ValueError("xlim must be a tuple of two values.") + if len(self.params.xlim) < 2 or self.params.xlim == (0, 0): + self.params.xlim = ( + np.min(time_array.magnitude), + np.max(time_array.magnitude), + ) + + win = pg.GraphicsLayoutWidget(show=self.params.show, title="Trace Plot") + win.setBackground(self.params.bg_color) + window_fill = pg.mkBrush( + color=tuple( + np.round(color_val * 255) + for color_val in mcolors.to_rgba(self.params.window_color, alpha=0.5) + ) + ) + window_fill_hover = pg.mkBrush( + color=tuple( + np.round(color_val * 255) + for color_val in mcolors.to_rgba(self.params.window_color, alpha=0.8) + ) + ) + # Handle window regions for interactive selection + windows_to_display = self.handle_windows() + if windows_to_display is not None: + window_items: list[list[pg.LinearRegionItem]] = [ + [] for _ in range(len(windows_to_display)) + ] + else: + window_items: list[list[pg.LinearRegionItem]] = [[]] + + region: pg.LinearRegionItem | None = None + channel_0: pg.PlotItem | None = None + for channel_index, channel in enumerate(trace_select.channel): + channel_tmp = win.addPlot(row=channel_index, col=0) # type: ignore + if channel_index == 0: + channel_0 = channel_tmp + channel_tmp.setXLink(channel_0) + channel_tmp.setLabel( + "left", + f"Channel {trace_select.channel_information.channel_number[channel_index]} " + f"({trace_select.channel_information.unit[channel_index]})", + color=self.params.axis_color, + ) + + if channel_index == len(trace_select.channel) - 1: + channel_tmp.setLabel("bottom", "Time (s)", color=self.params.axis_color) + + channel_tmp.setLabel( + "bottom", + f"Time ({time_array.units.dimensionality.string})", + color=self.params.axis_color, + ) + channel_tmp.setDownsampling(mode="subsample", auto=True) + channel_tmp.setClipToView(True) + channel_box = channel_tmp.getViewBox() + channel_box.setXRange(self.params.xlim[0], self.params.xlim[1]) + + for i in range(channel.data.shape[0]): + qt_color = utils.color_picker_qcolor( + length=channel.data.shape[0], + index=i, + color=self.params.color, + alpha=self.params.alpha, + ) + channel_tmp.plot( + time_array[i], + channel.data[i], + pen=pg.mkPen( + color=qt_color, + ), + ) + if windows_to_display != [(0, 0)]: + for win_index, win_item in enumerate(window_items): + if isinstance(window_items, list) and len(window_items) > 0: + region = pg.LinearRegionItem( + values=windows_to_display[win_index], + pen=pg.mkPen(color=self.params.window_color), + brush=window_fill, + hoverBrush=window_fill_hover, + ) + elif isinstance(window_items, tuple): + region = pg.LinearRegionItem( + values=windows_to_display, + pen=pg.mkPen(color=self.params.window_color), + brush=window_fill, + hoverBrush=window_fill_hover, + ) + else: + continue + win_item.append(region) + region.sigRegionChanged.connect( + make_region_callback(region, win_item, window_index=win_index) + ) + region.setZValue(10 + win_index) + channel_tmp.addItem(region) + + if self.params.average: + channel.channel_average(sweep_subset=self.params.sweep_subset) + channel_tmp.plot( + time_array[0, :], + channel.average.trace, + pen=pg.mkPen(color=self.params.avg_color, width=2), + ) + + return win diff --git a/ephys/classes/plot/plot_window_functions.py b/ephys/classes/plot/plot_window_functions.py new file mode 100644 index 0000000..c2b149c --- /dev/null +++ b/ephys/classes/plot/plot_window_functions.py @@ -0,0 +1,279 @@ +""" +Plotting functions for analyzing and visualizing electrophysiological data. +This module provides classes and functions for plotting traces and summary +measurements using both PyQtGraph and Matplotlib backends. +""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, TYPE_CHECKING + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from ephys import utils +from ephys.classes.plot.plot_params import PlotParams +from ephys.classes.class_functions import moving_average + +if TYPE_CHECKING: + from ephys.classes.window_functions import FunctionOutput + from ephys.classes.trace import Trace + + +class FunctionOutputPlot: + """Class for plotting traces and summary measurements with matplotlib.""" + + def __init__(self, function_output: FunctionOutput, **kwargs: Any) -> None: + """ + Initializes the FunctionOutputPlot class with function_output and + additional arguments. + + Args: + function_output (FunctionOutput): The function output to be plotted. + **kwargs: Additional keyword arguments for plot parameters. + """ + # self.trace = function_output.trace + self.function_output = function_output + self.params = PlotParams() + if kwargs: + self.params.update_params(**kwargs) + + +class FunctionOutputPyQt(FunctionOutputPlot): + """Class for plotting traces and summary measurements with PyQtGraph.""" + + def __init__(self, function_output: FunctionOutput, **kwargs: Any) -> None: + """ + Init FunctionOutputPyQt with function_output and params. + + Args: + function_output (FunctionOutput): Output to plot. + **kwargs: Plot params. + """ + super().__init__(function_output, **kwargs) + self.params.update_params(**kwargs) + + def plot( + self, + label_filter: list | str | None = None, + **kwargs: Any, + ) -> None: + """ + Plots the trace and/or summary measurements. + + Args: + label_filter (list | str | None): Labels to filter for plotting. + **kwargs: Additional plot parameters. + + Returns: + None + """ + # TODO: build PyQtGraph plot + if kwargs: + self.params.update_params(**kwargs) + if self.function_output.measurements.size == 0: + print("No measurements to plot") + return None + + if label_filter is None: + label_filter = [] + return None + + +class FunctionOutputMatplotlib(FunctionOutputPlot): + """Class for plotting traces and summary measurements with Matplotlib.""" + + def __init__(self, function_output: FunctionOutput, **kwargs: Any) -> None: + """ + Initializes the FunctionOutputMatplotlib class. + + Args: + function_output (FunctionOutput): The function output to plot. + theme (str, optional): Plot theme ('dark' or 'light'). Default 'dark'. + **kwargs: Additional plot parameters. + """ + super().__init__(function_output, **kwargs) + self.params.update_params(**kwargs) + + def plot( + self, + trace: Trace | None = None, + label_filter: list | str | None = None, + **kwargs: Any, + ) -> tuple[Figure, Axes | np.ndarray] | None: + """ + Plots the trace and/or summary measurements. + + Args: + trace (Trace, optional): Trace object to plot. If None, only summary + measurements are plotted. Default is None. + label_filter (list | str | None, optional): Labels to filter for plotting. + **kwargs: Additional plot parameters. + + Returns: + None + """ + self.params.update_params(**kwargs) + + align_onset: bool = self.params.__dict__.get("align_onset", True) + show: bool = self.params.__dict__.get("show", True) + return_fig: bool = self.params.__dict__.get("return_fig", False) + + fig_out: Figure | None = None + channel_axs_out: Axes | np.ndarray | None = None + + if self.function_output.measurements.size == 0: + print("No measurements to plot") + return None + if label_filter is None: + label_filter = [] + if trace is not None: + trace_select: Trace = trace.subset( + channels=self.function_output.channel, + signal_type=self.function_output.signal_type, + ) + trace_plot_params = deepcopy(self.params.__dict__) + trace_plot_params["show"] = False + trace_plot_params["return_fig"] = True + tmp = trace_select.plot( + **trace_plot_params, + ) + if isinstance(tmp, tuple): + fig, channel_axs = tmp + else: + raise TypeError( + "Trace plot did not return a valid figure or axes object." + ) + else: + fig, channel_axs = plt.subplots( + np.unique(self.function_output.channel).size, 1, sharex=True + ) + channel_count = np.unique(self.function_output.channel).size + unique_labels = np.unique(self.function_output.label) + if align_onset: + x_axis = self.function_output.location + else: + x_axis = self.function_output.time + for color_index, label in enumerate(unique_labels): + # add section to plot on channel by channel basis + for channel_index, channel_number in enumerate( + np.unique(self.function_output.channel) + ): + tmp_axs: Axes | None = None + if channel_count > 1: + if isinstance(channel_axs, np.ndarray): + tmp_axs = channel_axs[channel_index] + else: + if isinstance(channel_axs, Axes): + tmp_axs = channel_axs + if len(label_filter) > 0: + if label not in label_filter: + continue + label_idx = np.where( + (self.function_output.label == label) + & (self.function_output.channel == channel_number) + ) + label_colors = utils.color_picker( + length=len(unique_labels), index=color_index, color="gist_rainbow" + ) + if not align_onset: + y_smooth = moving_average( + self.function_output.measurements[label_idx], + len(label_idx[0]) // 10, + ) + if tmp_axs is not None: + tmp_axs.plot( + x_axis[label_idx], + y_smooth, + color=label_colors, + alpha=0.4, + lw=2, + ) + if tmp_axs is not None: + tmp_axs.plot( + x_axis[label_idx], + self.function_output.measurements[label_idx], + "o", + color=label_colors, + alpha=0.5, + label=label, + ) + if trace is None: + if isinstance(channel_axs, np.ndarray): + for channel_index, channel_number in enumerate( + np.unique(self.function_output.channel) + ): + if channel_index == len(channel_axs) - 1: + channel_axs[channel_index].set_xlabel("Time (s)") + channel_unit = np.unique( + self.function_output.unit[ + self.function_output.channel == channel_number + ] + ) + channel_axs[channel_index].set_ylabel( + f"Channel {int(channel_number)} " f"({channel_unit[0]})" + ) + else: + if isinstance(channel_axs, Axes): + channel_axs.set_xlabel("Time (s)", color=self.params.axis_color) + channel_unit = np.unique(self.function_output.unit) + channel_number = np.unique(self.function_output.channel) + channel_axs.set_ylabel( + f"Channel {int(channel_number)} " f"({channel_unit[0]})", + labelcolor=self.params.axis_color, + ) + if isinstance(channel_axs, np.ndarray): + channel_axs[0].legend(loc="best") + for single_axs in channel_axs: + single_axs.set_facecolor(self.params.bg_color) + single_axs.spines[["top", "right"]].set_visible(False) + single_axs.spines["bottom"].set_color(self.params.axis_color) + single_axs.spines["left"].set_color(self.params.axis_color) + single_axs.xaxis.label.set_color( + self.params.axis_color, + ) + single_axs.yaxis.label.set_color( + self.params.axis_color, + ) + single_axs.tick_params( + axis="x", + colors=self.params.axis_color, + ) + single_axs.tick_params( + axis="y", + colors=self.params.axis_color, + ) + else: + if isinstance(channel_axs, Axes): + channel_axs.set_facecolor(self.params.bg_color) + channel_axs.spines[["top", "right"]].set_visible(False) + channel_axs.spines["bottom"].set_color(self.params.axis_color) + channel_axs.spines["left"].set_color(self.params.axis_color) + channel_axs.xaxis.label.set_color( + self.params.axis_color, + ) + channel_axs.yaxis.label.set_color( + self.params.axis_color, + ) + channel_axs.tick_params( + axis="x", + colors=self.params.axis_color, + ) + channel_axs.tick_params( + axis="y", + colors=self.params.axis_color, + ) + channel_axs.legend(loc="best") + fig.set_facecolor(self.params.bg_color) + + if return_fig: + fig_out = deepcopy(fig) + channel_axs_out = deepcopy(channel_axs) + if show: + plt.show() + if return_fig and fig_out is not None and channel_axs_out is not None: + return fig_out, channel_axs_out + return None diff --git a/ephys/classes/trace.py b/ephys/classes/trace.py index 244ceed..35d0e7f 100644 --- a/ephys/classes/trace.py +++ b/ephys/classes/trace.py @@ -17,14 +17,16 @@ - ephys.classes.class_functions """ -from typing import Any, Optional +from __future__ import annotations +from typing import Any, Optional, TYPE_CHECKING from copy import deepcopy -import numpy as np +from datetime import datetime from uuid import uuid4 -import matplotlib.pyplot as plt -from matplotlib.axes import Axes +import numpy as np from quantities import Quantity -from datetime import datetime +import pyqtgraph as pg +from matplotlib.figure import Figure +from matplotlib.axes import Axes from ephys import utils from ephys.classes.class_functions import ( @@ -36,6 +38,10 @@ from ephys.classes.channels import ChannelInformation from ephys.classes.window_functions import FunctionOutput +if TYPE_CHECKING: + from ephys.classes.plot.plot_trace import TracePlotPyQt, TracePlotMatplotlib + + class Trace: """ Represents a trace object. @@ -50,38 +56,56 @@ class Trace: copy() -> Any: Returns a deep copy of the Trace object. - subset(channels: Any = all channels, can be a list, - signal_type: Any = 'voltage' and 'current', - rec_type: Any = all rec_types) -> Any: - Returns a subset of the Trace object based on the specified channels, signal_type, and - rec_type. - - average_trace(channels: Any = all channels, can be a list, - signal_type: Any = 'voltage' and 'current', can be a list, - rec_type: Any = all rec_types) -> Any: - Returns the average trace of the Trace object based on the specified channels, - signal_type, and rec_type. - - plot(signal_type: str, channels: list, average: bool = False, color: str ='k', - alpha: float = 0.5, avg_color: str = 'r'): - Plots the trace data based on the specified signal_type, channels, and other optional - parameters. + subset( + channels: Any = all channels, can be a list, + signal_type: Any = 'voltage' and 'current', + rec_type: Any = all rec_types + ) -> Any: + Returns a subset of the Trace object based on the specified + channels, signal_type, and rec_type. + + average_trace( + channels: Any = all channels, can be a list, + signal_type: Any = 'voltage' and 'current', can be a list, + rec_type: Any = all rec_types + ) -> Any: + Returns the average trace of the Trace object based on the + specified channels, signal_type, and rec_type. + + plot( + signal_type: str, channels: list, average: bool = False, + color: str = 'k', alpha: float = 0.5, avg_color: str = 'r' + ): + Plots the trace data based on the specified signal_type, + channels, and other optional parameters. """ - def __init__(self, file_path: str, quick_check: bool = True) -> None: - self.file_path = file_path - self.time = Quantity(np.array([]), units="s") - self.sampling_rate = None + def __init__(self, file_path: str = "", quick_check: bool = True) -> None: + self.file_path: str = file_path + self.time: Quantity = Quantity(np.array([]), units="s") + self.sampling_rate: Quantity | None = None self.rec_datetime: Optional[datetime] = None - self.channel = np.array([]) - self.channel_information = ChannelInformation() - self.sweep_count = None - self.object_id = str(uuid4()) - self.window_summary = FunctionOutput() + self.channel: np.ndarray = np.array([]) + self.channel_information: ChannelInformation = ChannelInformation() + self.sweep_count: int | None = None + self.object_id: str = str(uuid4()) + self.window_summary: FunctionOutput = FunctionOutput() + self.window: None | list = None + if self.file_path and len(self.file_path) > 0: + self.load(file_path=self.file_path, quick_check=quick_check) + + def load(self, file_path: str, quick_check: bool = True) -> None: + """ + Load the trace data from a file. + + Args: + file_path (str): The path to the file to load. + quick_check (bool, optional): If True, performs a quick check of the file. + """ if file_path.endswith(".wcp"): - wcp_trace(self, file_path, quick_check) + wcp_trace(trace=self, file_path=file_path, quick_check=quick_check) elif file_path.endswith(".abf"): - abf_trace(self, file_path, quick_check) + abf_trace(trace=self, file_path=file_path, quick_check=quick_check) else: print("File type not supported") if self.sampling_rate is not None: @@ -109,21 +133,21 @@ def subset( Args: channels (Any, optional): Channels to include in the subset. - Defaults to all channels. - signal_type (Any, optional): Types of signal_type to include in the subset. - Defaults to ['voltage', 'current']. + Defaults to all channels. + signal_type (Any, optional): Types of signal_type to include in the + subset. Defaults to ['voltage', 'current']. rec_type (Any, optional): Recording types to include in the subset. - Defaults to ''. + Defaults to ''. clamp_type (Any, optional): Clamp types to include in the subset. - Defaults to None. - channel_groups (Any, optional): Channel groups to include in the subset. - Defaults to None. - sweep_subset (Any, optional): Sweeps to include in the subset. Possible inputs can be - list, arrays or slice(). Defaults to None. - subset_index_only (bool, optional): If True, returns only the subset index. - Defaults to False. + Defaults to None. + channel_groups (Any, optional): Channel groups to include in the + subset. Defaults to None. + sweep_subset (Any, optional): Sweeps to include in the subset. + Possible inputs can be list, arrays or slice(). Defaults to None. + subset_index_only (bool, optional): If True, returns only the subset + index. Defaults to False. in_place (bool, optional): If True, modifies the object in place. - Defaults to False. + Defaults to False. Returns: Any: Subset of the experiment object. @@ -141,15 +165,16 @@ def subset( return self.channel_information return self - sweep_subset = _get_sweep_subset(self.time, sweep_subset) + sweep_subset = _get_sweep_subset(array=self.time, sweep_subset=sweep_subset) if in_place: subset_trace = self else: subset_trace = self.copy() - rec_type_get = utils.string_match( - rec_type, self.channel_information.recording_type + rec_type_get: np.ndarray = utils.string_match( + pattern=rec_type, string_list=self.channel_information.recording_type ) if clamp_type is None: + # By default, include both clamped (True) and unclamped (False) channels clamp_type = np.array([True, False]) clamp_type_get = np.isin(self.channel_information.clamped, np.array(clamp_type)) if channel_groups is None: @@ -180,14 +205,15 @@ def subset( ) if len(combined_index) > 0: - signal_type = self.channel_information.signal_type[combined_index] subset_trace.channel_information.channel_number = ( self.channel_information.channel_number[combined_index] ) subset_trace.channel_information.recording_type = ( self.channel_information.recording_type[combined_index] ) - subset_trace.channel_information.signal_type = signal_type + subset_trace.channel_information.signal_type = ( + self.channel_information.signal_type[combined_index] + ) subset_trace.channel_information.clamped = self.channel_information.clamped[ combined_index ] @@ -215,6 +241,7 @@ def subset( subset_trace.channel_information.unit = np.array([]) if subset_index_only: return subset_trace.channel_information + # sweep count starts at 1 subset_trace.sweep_count = subset_trace.time.shape[0] + 1 return subset_trace @@ -228,19 +255,23 @@ def set_time( """ Set the time axis for the given trace data. - Parameters: - - trace_data (Trace): The trace data object. - - align_to_zero (bool): If True, align the time axis to zero. Default is True. - - cumulative (bool): If True, set the time axis to cumulative. Default is False. - - stimulus_interval (float): The stimulus interval. Default is 0.0 (s). + Args: + align_to_zero (bool): If True, align the time axis to zero. + Default is True. + cumulative (bool): If True, set the time axis to cumulative. + Default is False. + stimulus_interval (float): The stimulus interval. + Default is 0.0 (s). + overwrite_time (bool): If True, overwrite the current time. + Default is True. Returns: - - Trace or None + Trace or None """ - tmp_time = deepcopy(self.time) - time_unit = tmp_time.units - start_time = Quantity(0, time_unit) + tmp_time: Quantity = deepcopy(self.time) + time_unit: Quantity = tmp_time.units + start_time = Quantity(data=0, units=time_unit) if self.sampling_rate is None: raise ValueError( "Sampling rate is not set." @@ -249,20 +280,20 @@ def set_time( sampling_interval = (1 / self.sampling_rate).rescale(time_unit).magnitude for sweep_index, sweep in enumerate(tmp_time): - sweep = Quantity(sweep, time_unit) + sweep = Quantity(data=sweep, units=time_unit) if align_to_zero: - start_time = Quantity(np.min(sweep.magnitude), time_unit) + start_time = Quantity(data=np.min(sweep.magnitude), units=time_unit) if cumulative: if sweep_index > 0: start_time = Quantity( - Quantity( - np.min(sweep.magnitude) + data=Quantity( + data=np.min(sweep.magnitude) - np.max(tmp_time[sweep_index - 1].magnitude), - time_unit, + units=time_unit, ).magnitude - stimulus_interval - sampling_interval, - time_unit, + units=time_unit, ) tmp_time[sweep_index] -= start_time if overwrite_time: @@ -272,17 +303,20 @@ def set_time( def rescale_time(self, time_unit: str = "s") -> None: """ - Rescale the time axis for the given trace data. + Rescale the time axis of the trace to the specified time unit. - Parameters: - - trace_data (Trace): The trace data object. - - time_unit (str): The time unit. Default is 's'. + Args: + time_unit (str): The desired time unit to rescale the time axis to + (default is 's'). Returns: - - None - """ + None - self.time = self.time.rescale(time_unit) + Notes: + This method modifies the `time` attribute of the trace in-place by + converting it to the specified unit. + """ + self.time = self.time.rescale(units=time_unit) def subtract_baseline( self, @@ -297,31 +331,24 @@ def subtract_baseline( """ Subtracts the baseline from the signal within a specified time window. - Parameters: - self : object - The instance of the class containing the signal data. - window : tuple, optional - A tuple specifying the start and end of the time window for baseline - calculation (default is (0, 0.1)). - channels : Any, optional - The channels to be processed. If None, all channels are processed - (default is None). - signal_type : Any, optional - The type of signal to be processed (e.g., 'voltage' or 'current'). - If None, all signal types are processed (default is None). - rec_type : str, optional - The type of recording (default is an empty string). - median : bool, optional - If True, the median value within the window is used as the baseline. - If False, the mean value is used (default is False). - overwrite : bool, optional - If True, the baseline-subtracted data will overwrite the original data. - If False, a copy of the data with the baseline subtracted will be - returned (default is False). + Args: + window (tuple, optional): A tuple specifying the start and end of the time window + for baseline calculation (default is (0, 0.1)). + channels (Any, optional): The channels to be processed. If None, all channels + are processed (default is None). + signal_type (Any, optional): The type of signal to be processed (e.g., 'voltage' + or 'current'). If None, all signal types are processed (default is None). + rec_type (str, optional): The type of recording (default is an empty string). + median (bool, optional): If True, the median value within the window is used as + the baseline. If False, the mean value is used (default is False). + overwrite (bool, optional): If True, the baseline-subtracted data will overwrite + the original data. If False, a copy of the data with the baseline subtracted + will be returned (default is False). + sweep_subset (Any, optional): Sweeps to include in the baseline subtraction. + Defaults to None. Returns: - Any - If overwrite is False, returns a copy of the data with the baseline + Any: If overwrite is False, returns a copy of the data with the baseline subtracted. If overwrite is True, returns None. """ @@ -375,6 +402,82 @@ def subtract_baseline( return trace_copy return None + def get_window( + self, + index: int | None = None, + ) -> tuple | list | None: + """ + Get the current window of the trace. + + Args: + index (int, optional): The index of the window to retrieve. If None, + returns the entire window. Defaults to None. + + Returns: + tuple or list: The current window of the trace. + """ + if self.window is None: + print("No window set for the trace.") + return None + if index is None: + return self.window + if self.window is not None and index >= len(self.window): + raise IndexError("Index out of range for the window list.") + if isinstance(self.window, list): + return self.window[index] + return None + + def add_window( + self, + window: tuple | list, + ) -> None: + """ + Add a window for the trace. + + Args: + window (tuple or list): The window to set for the trace. + """ + if isinstance(window, tuple): + if self.window is None: + self.window = [window] + elif isinstance(self.window, list): + self.window.append(window) + elif isinstance(window, list): + for win in window: + if not isinstance(win, tuple) or len(win) != 2: + raise TypeError("Each window must be a tuple of length 2.") + if self.window is None: + self.window = window + elif isinstance(self.window, list): + self.window.extend(window) + else: + raise ValueError("Window must be a tuple or a list.") + + def remove_window( + self, + index: int = -1, + all_windows: bool = False, + ) -> None: + """ + Remove a window from the trace. + + Args: + index (int, optional): The index of the window to remove. If None, + removes the last window. Defaults to None. + """ + if self.window is None: + return None + if index == -1 and not all_windows: + index = -1 + if all_windows: + self.window = None + if isinstance(self.window, list): + if index is not None and 0 <= index < len(self.window): + del self.window[index] + else: + self.window.pop() + return None + def window_function( self, window: list | None = None, @@ -390,42 +493,38 @@ def window_function( """ Apply a specified function to a subset of channels within given time windows. - Parameters: - ----------- - window : list, optional - List of tuples specifying the start and end of each window. Default is [(0, 0)]. - channels : Any, optional - Channels to be included in the subset. Default is None. - signal_type : Any, optional - Type of signal to be included in the subset. Default is None. - rec_type : str, optional - Type of recording to be included in the subset. Default is an empty string. - function : str, optional - Function to apply to the data. Supported functions are 'mean', 'median', 'max', - 'min', 'min_avg'. Default is 'mean'. - return_output : bool, optional - If True, the function returns the output. Default is False. - plot : bool, optional - If True, the function plots the output. Default is False. + Args: + window (list, optional): List of tuples specifying the start and end of each + window. Default is [(0, 0)]. + channels (Any, optional): Channels to be included in the subset. Default is None. + signal_type (Any, optional): Type of signal to be included in the subset. + Default is None. + rec_type (str, optional): Type of recording to be included in the subset. + Default is an empty string. + function (str, optional): Function to apply to the data. Supported functions are + 'mean', 'median', 'max', 'min', 'min_avg'. Default is 'mean'. + label (str, optional): Label for the output. Default is "". + sweep_subset (Any, optional): Sweeps to include. Default is None. + return_output (bool, optional): If True, the function returns the output. + Default is False. + plot (bool, optional): If True, the function plots the output. Default is False. Returns: - -------- - Any - The output of the applied function if return_output is True, otherwise None. + Any: The output of the applied function if return_output is True, otherwise None. Notes: - ------ - The function updates the `window_summary` attribute of the class with the output. + The function updates the `window_summary` attribute of the class with the output. """ if window is None: window = [(0, 0)] if function not in ["mean", "median", "max", "min", "min_avg"]: print("Function not supported") + return None if not isinstance(window, list): window = [window] sweep_subset = _get_sweep_subset(self.time, sweep_subset) - subset_channels = self.subset( + subset_channels: Trace = self.subset( channels=channels, signal_type=signal_type, rec_type=rec_type, @@ -448,15 +547,21 @@ def window_function( channel_index ], label=label, - unit=subset_channels.channel_information.unit[channel_index], + # unit=subset_channels.channel_information.unit[channel_index], ) if plot: - subset_channels.plot(trace=subset_channels, show=True, window_data=output) + output.plot(trace=subset_channels, plot_trace=True) if return_output: return output - self.window_summary.merge(output) + self.window_summary.merge(window_summary=output) return None + def reset_window_summary(self) -> None: + """ + Deletes the window summary. + """ + self.window_summary = FunctionOutput() + def average_trace( self, channels: Any = None, @@ -464,26 +569,38 @@ def average_trace( rec_type: Any = "", sweep_subset: Any = None, in_place: bool = True, - ) -> Any: + ) -> Any | None: """ - Calculates the average trace for the given channels, signal_type types, and recording type. + Calculates the average trace for the given channels, signal_type types, + and recording type. - Parameters: - - channels (Any): The channels to calculate the average trace for. - If None, uses the first channel type. - - signal_type (Any): The signal_type types to calculate the average trace for. - Defaults to ['voltage', 'current']. - - rec_type (Any): The recording type to calculate the average trace for. + Args: + channels (Any, optional): Channels to include in the average. + Defaults to all channels. + signal_type (Any, optional): Signal types to include in the average. + Defaults to ['voltage', 'current']. + rec_type (Any, optional): Recording type to include in the average. + Defaults to "". + sweep_subset (Any, optional): Sweeps to include in the average. + Defaults to None. + in_place (bool, optional): If True, modifies the object in place. + If False, returns a new Trace object. Returns: - - Any: The average trace object. + None if in_place is True. + Trace: The average trace object if in_place is False. + + Note: + The return type depends on the value of `in_place`. If `in_place` is True, + the method modifies the current object and returns None. If False, it returns + a new Trace object with the averaged data. """ if channels is None: channels = self.channel_information.channel_number if signal_type is None: signal_type = ["voltage", "current"] - sweep_subset = _get_sweep_subset(self.time, sweep_subset) + sweep_subset = _get_sweep_subset(array=self.time, sweep_subset=sweep_subset) if in_place: avg_trace = self else: @@ -503,166 +620,62 @@ def average_trace( return avg_trace def plot( - self, - signal_type: str = "", - channels: np.ndarray = np.array([], dtype=np.int64), - average: bool = False, - color: str = "black", - alpha: float = 0.5, - avg_color: str = "red", - align_onset: bool = True, - sweep_subset: Any = None, - window: tuple = (0, 0), - xlim: tuple = (), - show: bool = True, - return_fig: bool = False, - ) -> None | tuple: + self, backend: str = "matplotlib", **kwargs + ) -> None | pg.GraphicsLayoutWidget | tuple[Figure, Axes]: """ - Plots the traces for the specified channels. + Plots the traces using the specified backend. Args: - signal_type (str): The type of signal_type to use. Must be either 'current' or - 'voltage'. - channels (list, optional): The list of channels to plot. If None, all channels - will be plotted. - Defaults to None. - average (bool, optional): Whether to plot the average trace. - Defaults to False. - color (str, optional): The color of the individual traces. Can be a colormap. - Defaults to 'black'. - alpha (float, optional): The transparency of the individual traces. - Defaults to 0.5. - avg_color (str, optional): The color of the average trace. - Defaults to 'red'. - align_onset (bool, optional): Whether to align the traces on the onset. - Defaults to True. - sweep_subset (Any, optional): The subset of sweeps to plot. - Defaults to None. - window (tuple, optional): The time window to plot. - Defaults to (0, 0). - show (bool, optional): Whether to display the plot. - Defaults to True. - return_fig (bool, optional): Whether to return the figure. - Defaults to False. + backend (str): The plotting backend to use. Options are 'matplotlib' or 'pyqt'. + **kwargs: Additional keyword arguments for the plotting function. Returns: - None or Figure: If show is True, returns None. If return_fig is True, - returns the figure. + None or pg.GraphicsLayoutWidget: If using pyqtgraph, returns the plot widget. """ - - if len(channels) == 0: - channels = self.channel_information.channel_number - sweep_subset = _get_sweep_subset(self.time, sweep_subset) - trace_select = self.subset( - channels=channels, signal_type=signal_type, sweep_subset=sweep_subset - ) - - fig, channel_axs = plt.subplots(len(trace_select.channel), 1, sharex=True) - - if len(trace_select.channel) == 0: - print("No traces found.") - return None - - if align_onset: - time_array = trace_select.set_time( - align_to_zero=True, - cumulative=False, - stimulus_interval=0.0, - overwrite_time=False, - ) - else: - time_array = trace_select.time - - tmp_axs: Axes | None = None - for channel_index, channel in enumerate(trace_select.channel): - if len(trace_select.channel) == 1: - if isinstance(channel_axs, Axes): - tmp_axs = channel_axs - else: - if isinstance(channel_axs, np.ndarray): - if isinstance(channel_axs[channel_index], Axes): - tmp_axs = channel_axs[channel_index] - if tmp_axs is None: - pass - else: - for i in range(channel.data.shape[0]): - tmp_axs.plot( - time_array[i, :], - channel.data[i, :], - color=utils.trace_color( - traces=channel.data, index=i, color=color - ), - alpha=alpha, - ) - if window != (0, 0): - tmp_axs.axvspan( - xmin=window[0], xmax=window[1], color="gray", alpha=0.1 - ) - if average: - channel.channel_average(sweep_subset=sweep_subset) - tmp_axs.plot( - time_array[0, :], channel.average.trace, color=avg_color - ) - tmp_axs.set_ylabel( - f"Channel {trace_select.channel_information.channel_number[channel_index]} " - f"({trace_select.channel_information.unit[channel_index]})" - ) - # tmp_axs.set_ylabel(f'Channel') - if isinstance(tmp_axs, Axes): - tmp_axs.set_xlabel( - f"Time ({trace_select.time.units.dimensionality.string})" - ) - if len(xlim) > 0: - tmp_axs.set_xlim(xlim[0], xlim[1]) - plt.tight_layout() - if show: - plt.show() - # return None - if return_fig: - return fig, channel_axs - return None + # Import here to avoid circular imports + # pylint: disable=import-outside-toplevel + from ephys.classes.plot.plot_trace import TracePlotPyQt, TracePlotMatplotlib + + if backend == "matplotlib": + plot_out = TracePlotMatplotlib(trace=self, **kwargs) + return plot_out.plot() + if backend == "pyqt": + plot_out = TracePlotPyQt(trace=self, **kwargs) + return plot_out.plot() + raise ValueError("Unsupported backend. Use 'matplotlib' or 'pyqt'.") def plot_summary( self, - show_trace: bool = True, + plot_trace: bool = True, align_onset: bool = True, - label_filter: list | str = "", - color="black", - show=True, + label_filter: list | str | None = None, + **kwargs: Any, ) -> None: """ - Plots a summary of the experiment data. - - Parameters: - ----------- - show_trace : bool, optional - If True, includes the trace in the plot. Default is True. - align_onset : bool, optional - If True, aligns the plot on the onset. Default is True. - label_filter : list or str, optional - A filter to apply to the labels. Default is None. - color : str, optional - The color to use for the trace plot. Default is 'black'. + Plot a summary of the experiment data. + + Args: + plot_trace (bool, optional): If True, include the trace in the plot. + Default is True. + align_onset (bool, optional): If True, align the plot on the onset. + Default is True. + label_filter (list or str, optional): Filter to apply to the labels. + Default is None. + **kwargs: Additional keyword arguments for the plotting function. Returns: - -------- - None + None """ - if label_filter == "": + if label_filter == "" or label_filter is None: label_filter = [] if self.window_summary is not None: - if show_trace: - self.window_summary.plot( - trace=self, - align_onset=align_onset, - show=show, - label_filter=label_filter, - color=color, - ) - else: - self.window_summary.plot( - align_onset=align_onset, show=show, label_filter=label_filter - ) + self.window_summary.plot( + trace=self, + align_onset=align_onset, + label_filter=label_filter, + plot_trace=plot_trace, + **kwargs, + ) else: print("No summary data found") diff --git a/ephys/classes/window_functions.py b/ephys/classes/window_functions.py index 8422456..f0c7fc8 100644 --- a/ephys/classes/window_functions.py +++ b/ephys/classes/window_functions.py @@ -1,54 +1,70 @@ +""" +This module defines the FunctionOutput class, which is used to handle the output of +various functions applied to electrophysiological trace data. It includes methods for +appending measurements, merging data, calculating differences and ratios, plotting +results, and converting to dictionary or DataFrame formats. +""" + from __future__ import annotations from copy import deepcopy from typing import Any, TYPE_CHECKING - -import matplotlib.pyplot as plt import numpy as np import pandas as pd -from matplotlib.axes import Axes from quantities import Quantity +from matplotlib.figure import Figure +from matplotlib.axes import Axes from ephys import utils -from ephys.classes.class_functions import ( - _get_time_index, - moving_average, -) +from ephys.classes.plot.plot_params import PlotParams +from ephys.classes.class_functions import _get_time_index if TYPE_CHECKING: from ephys.classes.trace import Trace class FunctionOutput: - """A class to handle the output of various functions applied to electrophysiological trace data. + """ + A class to handle the output of various functions applied to electrophysiological + trace data. + + Args: + function_name (str): The name of the function to be applied to the trace data. Attributes: function_name (str): The name of the function to be applied to the trace data. - measurements (np.ndarray): An array to store the measurements obtained from the trace data. - location (np.ndarray): An array to store the locations corresponding to the measurements. + measurements (np.ndarray): An array to store the measurements obtained from the + trace data. + location (np.ndarray): An array to store the locations corresponding to the + measurements. sweep (np.ndarray): An array to store the sweep indices. channel (np.ndarray): An array to store the channel numbers. - signal_type (np.ndarray): An array to store the types of signals (e.g., current, voltage). + signal_type (np.ndarray): An array to store the types of signals (e.g., current, + voltage). window (np.ndarray): An array to store the time windows used for measurements. label (np.ndarray): An array to store the labels associated with the measurements. - time (np.ndarray): An array to store the time points corresponding to the measurements. + time (np.ndarray): An array to store the time points corresponding to the + measurements. Methods: __init__(self, function_name: str) -> None: Initializes the FunctionOutput object with the given function name. - append(self, trace: Trace, window: tuple, channels: Any = None, signal_type: Any = None, - rec_type: str = '', avg_window_ms: float = 1.0, label: str = '') -> None: + append(self, trace: Trace, window: tuple, channels: Any = None, + signal_type: Any = None, rec_type: str = '', avg_window_ms: float = 1.0, + label: str = '') -> None: Appends measurements and related information from the given trace data to the FunctionOutput object. merge(self, window_summary, remove_duplicates=False) -> None: - Merges the measurements, location, sweep, window, signal_type, and channel attributes - from the given window_summary object into the current object. Optionally removes - duplicates from these attributes after merging. + Merges the measurements, location, sweep, window, signal_type, and channel + attributes from the given window_summary object into the current object. + Optionally removes duplicates from these attributes after merging. - label_diff(self, labels: list = [], new_name: str = '', time_label: str = '') -> None: - Calculates the difference between two sets of measurements and appends the result. + label_diff(self, labels: list = [], new_name: str = '', time_label: str = '') + -> None: + Calculates the difference between two sets of measurements and appends the + result. plot(self, trace: Trace = None, show: bool = True, align_onset: bool = True, label_filter: list | str = [], color='black') -> None: @@ -59,12 +75,15 @@ class FunctionOutput: to_dataframe(self) -> pd.DataFrame: Converts the experiment object to a pandas DataFrame. + delete_label(self, label: str | list) -> None: Deletes a label from the measurements. """ def __init__(self, function_name: str = "") -> None: + self.trace: Trace | None = None self.function_name = function_name + self.function = np.array([], dtype=str) self.measurements = np.array([]) self.location = np.array([]) self.sweep = np.array([]) @@ -73,7 +92,7 @@ def __init__(self, function_name: str = "") -> None: self.window = np.ndarray(dtype=object, shape=(0, 2)) self.label = np.array([]) self.time = np.array([]) - self.unit = [] + self.unit = np.array([]) def append( self, @@ -84,33 +103,27 @@ def append( rec_type: str = "", avg_window_ms: float = 1.0, label: str = "", - unit: str = "", ) -> None: """ Appends measurements from a given trace within a specified time window. - Parameters: - ----------- - trace : Trace - The trace object containing the data to be analyzed. - window : tuple - A tuple specifying the start and end times of the window for measurement. - channels : Any, optional - The channels to be included in the subset of the trace. Default is None. - signal_type : Any, optional - The type of signal to be included in the subset of the trace. Default is None. - rec_type : str, optional - The recording type to be included in the subset of the trace. Default is an - empty string. - avg_window_ms : float, optional - The averaging window size in milliseconds for the 'min_avg' function. Default - is 1.0 ms. - label : str, optional - A label to be associated with the measurements. Default is an empty string. + Args: + trace (Trace): The trace object containing the data to be analyzed. + window (tuple): A tuple specifying the start and end times of the window + for measurement. + channels (Any, optional): The channels to be included in the subset of the + trace. Default is None. + signal_type (Any, optional): The type of signal to be included in the + subset of the trace. Default is None. + rec_type (str, optional): The recording type to be included in the subset + of the trace. Default is an empty string. + avg_window_ms (float, optional): The averaging window size in milliseconds + for the 'min_avg' function. Default is 1.0 ms. + label (str, optional): A label to be associated with the measurements. + Default is an empty string. Returns: - -------- - None + None """ trace_subset = trace.subset( channels=channels, signal_type=signal_type, rec_type=rec_type @@ -122,11 +135,11 @@ def append( stimulus_interval=0.0, overwrite_time=True, ) - for channel_index, _ in enumerate( + self.trace = trace + for channel_index, channel in enumerate( trace_subset.channel_information.channel_number ): tmp_location = np.array([]) - array_index = trace_subset.channel_information.array_index[channel_index] time_window_size = Quantity(avg_window_ms, "ms") channel_signal_type = trace_subset.channel_information.signal_type[ channel_index @@ -158,7 +171,7 @@ def append( ), trace_subset.time.units, ) - if channel_signal_type == "voltage" or channel_signal_type == "current": + if channel_signal_type in ["voltage", "current"]: windowed_trace = np.array( [ row[start:end] @@ -169,29 +182,6 @@ def append( ) ] ) - - # if channel_signal_type == "current": - # windowed_trace = np.array( - # [ - # row[start:end] - # for row, start, end in zip( - # trace_subset.current[array_index, :, :], - # window_start_index, - # window_end_index, - # ) - # ] - # ) - # elif channel_signal_type == "voltage": - # windowed_trace = np.array( - # [ - # row[start:end] - # for row, start, end in zip( - # trace_subset.voltage[array_index, :, :], - # window_start_index, - # window_end_index, - # ) - # ] - # ) else: print("Signal type not found") return None @@ -246,7 +236,7 @@ def append( [ row[(start + window_index) : (end - window_index)] for row, start, end in zip( - trace_subset.voltage[array_index, :, :], + trace_subset.channel[channel_index].data.magnitude, window_start_index, window_end_index, ) @@ -308,6 +298,13 @@ def append( ] ).mean(axis=1), ) + if label == "": + # find the function_name in the self.function + label = self.function_name + + self.function = np.append( + self.function, np.repeat(self.function_name, sweep_dim) + ) self.location = np.append(self.location, tmp_location) self.sweep = np.append(self.sweep, np.arange(1, sweep_dim + 1)) self.window = np.vstack((self.window, np.tile(window, (sweep_dim, 1)))) @@ -322,14 +319,15 @@ def append( self.channel = np.append( self.channel, np.repeat( - trace_subset.channel_information.channel_number[channel_index], + channel, sweep_dim, ), ) - self.unit.append( + self.unit = np.append( + self.unit, np.repeat( trace_subset.channel_information.unit[channel_index], sweep_dim - ) + ), ) self.time = np.append( @@ -349,15 +347,15 @@ def append( def merge(self, window_summary, remove_duplicates=False) -> None: """ - Merges the measurements, location, sweep, window, signal_type, and channel attributes - from the given window_summary object into the current object. Optionally removes duplicates - from these attributes after merging. + Merges the measurements, location, sweep, window, signal_type, and channel + attributes from the given window_summary object into the current object. + Optionally removes duplicates from these attributes after merging. Args: - window_summary (object): An object containing measurements, location, sweep, window, - signal_type, and channel attributes to be merged. - remove_duplicates (bool, optional): If True, removes duplicate entries from the merged - attributes. Defaults to True. + window_summary (object): An object containing measurements, location, + sweep, window, signal_type, and channel attributes to be merged. + remove_duplicates (bool, optional): If True, removes duplicate entries + from the merged attributes. Defaults to True. Returns: None @@ -369,8 +367,14 @@ def merge(self, window_summary, remove_duplicates=False) -> None: self.window = np.vstack((self.window, window_summary.window)) self.signal_type = np.append(self.signal_type, window_summary.signal_type) self.channel = np.append(self.channel, window_summary.channel) - self.label = np.append(self.label, window_summary.label) + self.label = np.append( + self.label, utils.unique_label_name(self.label, window_summary.label) + ) + self.function = np.append(self.function, window_summary.function) + self.unit = np.append(self.unit, window_summary.unit) self.time = np.append(self.time, window_summary.time) + self.trace = window_summary.trace + if remove_duplicates: np.unique(self.measurements) self.measurements = np.unique(self.measurements) @@ -380,6 +384,8 @@ def merge(self, window_summary, remove_duplicates=False) -> None: self.signal_type = np.unique(self.signal_type) self.channel = np.unique(self.channel) self.label = np.unique(self.label) + self.function = np.unique(self.function) + self.unit = np.unique(self.unit) def label_diff( self, labels: list | None = None, new_name: str = "", time_label: str = "" @@ -387,13 +393,15 @@ def label_diff( """ Calculate the difference between two sets of measurements and append the result. - Parameters: - labels (list): Labels whose measurements will be used to calculate the difference. - new_name (str): Label name for the new set of measurements. - time_label (str): Label to identify the time points for the new measurements. + Args: + labels (list): Labels whose measurements will be used to calculate the + difference. + new_name (str): Label name for the new set of measurements. + time_label (str): Label to identify the time points for the new + measurements. Returns: - None + None """ if labels is None: @@ -406,6 +414,8 @@ def label_diff( label_index_2 = np.where(self.label == labels[1]) time_label_index = np.where(self.label == time_label) diff = self.measurements[label_index_1] - self.measurements[label_index_2] + self.function = np.append(self.function, np.repeat("diff", len(diff))) + self.measurements = np.append(self.measurements, diff) self.location = np.append(self.location, self.location[time_label_index]) self.sweep = np.append(self.sweep, self.sweep[time_label_index]) @@ -425,12 +435,15 @@ def label_ratio( ) -> None: """ Calculate the ratio between two sets of measurements and append the result. - Parameters: - labels (list): Labels whose measurements will be used to calculate the ratio. - new_name (str): Label name for the new set of measurements. - time_label (str): Label to identify the time points for the new measurements. + + Args: + labels (list): Labels whose measurements will be used to calculate the ratio. + new_name (str): Label name for the new set of measurements. + time_label (str): Label to identify the time points for the new + measurements. + Returns: - None + None """ if labels is None: @@ -458,114 +471,83 @@ def label_ratio( def plot( self, + plot_trace: bool = True, trace: Trace | None = None, - show: bool = True, - align_onset: bool = True, label_filter: list | str | None = None, - color="black", - ) -> None: + backend: str = "matplotlib", # remove default after setting up pyqt + **kwargs, + ) -> None | tuple[Figure, Axes | np.ndarray] | None: """ - Plots the trace and/or summary measurements. + Plot the trace and/or summary measurements. - Parameters: - trace (Trace, optional): The trace object to be plotted. If None, only the summary - measurements are plotted. Default is None. - show (bool, optional): If True, the plot will be displayed. Default is True. - summary_only (bool, optional): If True, only the summary measurements will be - plotted. Default is True. + Args: + plot_trace (bool, optional): If True, plot the trace data. + Default is True. + trace (Trace or None, optional): The trace object to be plotted. + If None, uses the trace stored in the FunctionOutput object. + label_filter (list, str, or None, optional): Labels to filter the + measurements for plotting. + backend (str, optional): The plotting backend to use + ('matplotlib' or 'pyqt'). Default is 'matplotlib'. + **kwargs: Additional keyword arguments for plotting. Returns: - None + FunctionOutputPyQt or FunctionOutputMatplotlib: The plot output object. + + Raises: + ValueError: If an unsupported backend is specified. + TypeError: If the trace is not an instance of Trace when plot_trace is True. """ - if self.measurements.size == 0: - print("No measurements to plot") - return None - if label_filter is None: - label_filter = [] - _, channel_axs = None, None - if trace is not None: - # self.channel = np.unique(np.array(self.channel)) - trace_select = trace.subset( - channels=self.channel, signal_type=self.signal_type + + # pylint:disable=import-outside-toplevel + from ephys.classes.trace import Trace + + plot_params = PlotParams() + plot_params.update_params(**kwargs) + if plot_trace: + if trace is not None: + if trace is None: + trace = self.trace + if not isinstance(trace, Trace): + raise TypeError("trace must be an instance of Trace.") + else: + trace = None + if backend == "matplotlib": + # pylint:disable=import-outside-toplevel + from ephys.classes.plot.plot_window_functions import ( + FunctionOutputMatplotlib, ) - # trace_select.plot(show=False, align_onset=align_onset, color=color) - _, channel_axs = trace_select.plot( - show=False, align_onset=align_onset, color=color, return_fig=True + + plot_output = FunctionOutputMatplotlib( + function_output=self, **plot_params.__dict__ + ) + elif backend == "pyqt": + # pylint:disable=import-outside-toplevel + from ephys.classes.plot.plot_window_functions import FunctionOutputPyQt + + plot_output = FunctionOutputPyQt( + function_output=self, **plot_params.__dict__ ) else: - _, channel_axs = plt.subplots(np.unique(self.channel).size, 1, sharex=True) - channel_count = np.unique(self.channel).size - unique_labels = np.unique(self.label) - if align_onset: - x_axis = self.location - else: - x_axis = self.time - for color_index, label in enumerate(unique_labels): - # add section to plot on channel by channel basis - for channel_index, channel_number in enumerate(np.unique(self.channel)): - tmp_axs: Axes | None = None - if channel_count > 1: - if isinstance(channel_axs, np.ndarray): - tmp_axs = channel_axs[channel_index] - else: - if isinstance(channel_axs, Axes): - tmp_axs = channel_axs - if len(label_filter) > 0: - if label not in label_filter: - continue - label_idx = np.where( - (self.label == label) & (self.channel == channel_number) - ) - label_colors = utils.color_picker( - length=len(unique_labels), index=color_index, color="gist_rainbow" - ) - if not align_onset: - y_smooth = moving_average( - self.measurements[label_idx], len(label_idx[0]) // 10 - ) - if tmp_axs is not None: - tmp_axs.plot( - x_axis[label_idx], - y_smooth, - color=label_colors, - alpha=0.4, - lw=2, - ) - if tmp_axs is not None: - tmp_axs.plot( - x_axis[label_idx], - self.measurements[label_idx], - "o", - color=label_colors, - alpha=0.5, - label=label, - ) - # TODO: add y and x labels to the plot when no trace is provided. - if (unique_labels.size != 1) and (unique_labels[0] != ""): - if isinstance(channel_axs, np.ndarray): - channel_axs[0].legend(loc="best") - else: - if isinstance(channel_axs, Axes): - channel_axs.legend(loc="best") + raise ValueError( + f"Unsupported backend: {backend}. Use 'matplotlib' or 'pyqt'." + ) - if show: - plt.show() + plot_output.plot( + trace=trace, + label_filter=label_filter, + ) def to_dict(self): """ - Convert the experiment object to a dictionary representation. + Convert the FunctionOutput object to a dictionary representation. Returns: - dict: A dictionary containing the following keys: - - 'measurements': The measurements associated with the experiment. - - 'location': The location of the experiment. - - 'sweep': The sweep information of the experiment. - - 'window': The window information of the experiment. - - 'signal_type': The type of signal used in the experiment. - - 'channel': The channel information of the experiment. + dict: A dictionary containing the function output attributes. """ return { + "function": self.function, "measurements": self.measurements, "location": self.location, "sweep": self.sweep, @@ -573,16 +555,16 @@ def to_dict(self): "signal_type": self.signal_type, "channel": self.channel, "label": self.label, + "unit": self.unit, "time": self.time, } def to_dataframe(self): """ - Convert the experiment object to a pandas DataFrame. + Convert the FunctionOutput object to a pandas DataFrame. Returns: - pandas.DataFrame: A DataFrame containing the measurements, location, - sweep, window, signal_type, and channel information. + pandas.DataFrame: A DataFrame containing the function output attributes. """ tmp_dictionary = self.to_dict() @@ -590,7 +572,7 @@ def to_dataframe(self): tmp_dictionary["window"] = window_tmp return pd.DataFrame(tmp_dictionary) - def delete_label(self, label: str | list) -> None: + def delete_label(self, label: str | list | None = None) -> None: """ Delete a label from the measurements. @@ -600,10 +582,23 @@ def delete_label(self, label: str | list) -> None: Returns: None """ + if label is None: + self.function = np.array([], dtype=str) + self.measurements = np.array([]) + self.location = np.array([]) + self.sweep = np.array([]) + self.channel = np.array([]) + self.signal_type = np.array([]) + self.window = np.ndarray(dtype=object, shape=(0, 2)) + self.label = np.array([]) + self.time = np.array([]) + self.unit = np.array([]) + return None if isinstance(label, str): label = [label] for label_i in label: label_index = np.where(self.label == label_i) + self.function = np.delete(self.function, label_index) self.measurements = np.delete(self.measurements, label_index) self.location = np.delete(self.location, label_index) self.sweep = np.delete(self.sweep, label_index) @@ -612,3 +607,5 @@ def delete_label(self, label: str | list) -> None: self.channel = np.delete(self.channel, label_index) self.label = np.delete(self.label, label_index) self.time = np.delete(self.time, label_index) + self.unit = np.delete(self.unit, label_index) + return None diff --git a/ephys/utils.py b/ephys/utils.py index b8224bf..c48935b 100644 --- a/ephys/utils.py +++ b/ephys/utils.py @@ -3,47 +3,78 @@ It includes functions to check if elements in a list or numpy array match a given pattern. """ + +import numpy as np from typing import Any -from re import compile as compile_regex +from re import compile as compile_regex, match from matplotlib import colormaps -from matplotlib.colors import is_color_like -import numpy as np +from matplotlib.colors import is_color_like, to_rgb +from pyqtgraph import mkColor +from pyqtgraph.Qt.QtGui import QColor -def color_picker(length: int, index: int, color: str = "black") -> str | np.ndarray: +def color_picker( + length: int, index: int, color: str = "black", alpha: float = 1.0 +) -> str | np.ndarray: """ Selects a color from a colormap or validates a given color string. - Parameters: - length (int): Number of colors in the colormap. - index (int): Index of the color to select. - color (str): Name of the colormap or a color string. Defaults to 'black'. + Args: + length (int): Number of colors in the colormap. + index (int): Index of the color to select. + color (str): Name of the colormap or a color string. Defaults to 'black'. + alpha (float): Alpha value for the color. Defaults to 1.0. Returns: - str | np.ndarray: The selected color. + str | np.ndarray: The selected color. Notes: - - Defaults to 'viridis' colormap if the color is invalid. + - Defaults to 'viridis' colormap if the color is invalid. """ + if not 0 <= alpha <= 1: + raise ValueError("Alpha value must be between 0 and 1") if color in colormaps.keys(): color_map = colormaps[color] - color = color_map(np.linspace(0, 1, length))[index] + color = color_map(np.linspace(0, 1, length), alpha)[index] elif is_color_like(color): pass else: print("Invalid color. Default to 'viridis'.") color_map = colormaps["viridis"] - color = color_map(np.linspace(0, 1, length))[index] + color = color_map(np.linspace(0, 1, length), alpha)[index] return color +def color_picker_qcolor( + length: int, index: int, color: str = "black", alpha: float = 1.0 +) -> QColor: + """ + Selects a color from a colormap or validates a given color string and returns + it as a QColor object. + Args: + length (int): Number of colors in the colormap. + index (int): Index of the color to select. + color (str): Name of the colormap or a color string. Defaults to 'black'. + alpha (float): Alpha value for the color. Defaults to 1.0. + Returns: + QColor: The selected color as a QColor object. + """ + selected_color: str | np.ndarray = color_picker(length, index, color, alpha) + rgba_val: tuple = (0, 0, 0, 255) # Default RGBA value + if isinstance(selected_color, np.ndarray): + rgba_val = tuple(int(c * 255) for c in selected_color) + elif isinstance(selected_color, str): + rgba_val = tuple(int(c * 255) for c in to_rgb(selected_color) + (alpha,)) + return mkColor(rgba_val) + + def trace_color( traces: np.ndarray, index: int, color: str = "black" ) -> str | np.ndarray: """ Returns the color for a specific trace in a given colormap. - Parameters: + Args: traces (np.ndarray): The array of traces. index (int): The index of the trace. color (str): The name of the colormap or a specific color. Default is 'black'. @@ -58,15 +89,15 @@ def string_match(pattern: Any, string_list: Any) -> np.ndarray: """ Check if the given pattern matches any element in the input list. - Parameters: - pattern (str, list, or numpy array): The pattern to match against. - It can be a string, a list of strings, or a numpy array of strings. - input (str, list, or numpy array): The input list to check for matches. - It can be a string, a list of strings, or a numpy array of strings. + Args: + pattern (str, list, or numpy array): The pattern to match against. + It can be a string, a list of strings, or a numpy array of strings. + string_list (str, list, or numpy array): The input list to check for matches. + It can be a string, a list of strings, or a numpy array of strings. Returns: - np.ndarray: A boolean numpy array indicating whether each element in the - string_input list matches the pattern. + np.ndarray: A boolean numpy array indicating whether each element in the + string_list matches the pattern. """ def single_match(pattern_string: str, input_string: str): @@ -87,3 +118,44 @@ def single_match(pattern_string: str, input_string: str): raise ValueError("Input list must be a str, list or numpy array of strings.") pattern_string = "|".join(pattern) return np.array([single_match(pattern_string, i) for i in string_list]) + + +def check_label_name(label_array: np.ndarray, label_name: str) -> str: + """ + Checks if label_name exists in label_array and returns a unique label by + appending a numeric suffix if needed. + + Args: + label_array (np.ndarray): Array of existing label names. + label_name (str): The label name to check for uniqueness. + + Returns: + str: A unique label name. + """ + if label_array.size == 0 or label_name not in label_array: + return label_name + + regex_pattern = rf"^{label_name}_(\d+)$" + suffixes = [int(m.group(1)) for s in label_array if (m := match(regex_pattern, s))] + next_suffix = max(suffixes, default=0) + 1 + return f"{label_name}_{next_suffix}" + + +def unique_label_name(label_array: np.ndarray, label_name: np.ndarray) -> np.ndarray: + """ + Returns a unique label based on label_name, appending a numeric suffix if needed. + + Args: + label_array (np.ndarray): Array of existing label names. + label_name (np.ndarray): Array of label names to make unique. + + Returns: + np.ndarray: Array of unique label names. + """ + if isinstance(label_name, np.ndarray): + output_names = label_name.copy() + for name in np.unique(label_name): + if name in label_array: + output_names[output_names == name] = check_label_name(label_array, name) + return output_names + raise TypeError("label_name must be a string or numpy array of strings.") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..593c90b --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,82 @@ +from __future__ import annotations +from typing import Any +import numpy as np +import re as re + + +def string_match(pattern: Any, input_list: Any) -> np.ndarray: + """ + Check if elements in the input list match the given pattern. + + Args: + pattern (str or list): The pattern to match against. It can be a single string or a list of strings. + input_list (list or np.array): The list of elements to check for a match. It can be a list or a numpy array. + + Returns: + np.array: A boolean numpy array indicating whether each element in the input list matches the pattern. + """ + if type(pattern) == str: + pattern = [pattern] + if type(pattern) != list: + raise TypeError("Pattern must be a string or list of strings.") + return None + if type(input_list) == list: + input_list = np.array(input_list) + if type(input_list) != np.ndarray: + raise TypeError("Input list must be a list or numpy array.") + return None + pattern_string = '|'.join(pattern) + r = re.compile(pattern_string) + vmatch = np.vectorize(lambda x:bool(r.match(x))) + return vmatch(input_list) + +def test_string_match(): + # Test case 1: Single pattern, single element match + pattern = 'apple' + input_list = ['apple', 'banana', 'cherry'] + expected_output = np.array([True, False, False]) + assert np.array_equal(string_match(pattern, input_list), expected_output) + + # Test case 2: Single pattern, multiple element match + pattern = 'a' + input_list = ['apple', 'banana', 'cherry'] + expected_output = np.array([True, True, False]) + assert np.array_equal(string_match(pattern, input_list), expected_output) + + # Test case 3: Multiple patterns, single element match + pattern = ['apple', 'banana'] + input_list = ['apple', 'banana', 'cherry'] + expected_output = np.array([True, True, False]) + assert np.array_equal(string_match(pattern, input_list), expected_output) + + # Test case 4: Multiple patterns, multiple element match + pattern = ['a', 'b'] + input_list = ['apple', 'banana', 'cherry'] + expected_output = np.array([True, True, False]) + assert np.array_equal(string_match(pattern, input_list), expected_output) + + # Test case 5: Empty pattern, empty input list + pattern = [] + input_list = [] + expected_output = np.array([]) + assert np.array_equal(string_match(pattern, input_list), expected_output) + + # Test case 6: Invalid pattern type + pattern = 123 + input_list = ['apple', 'banana', 'cherry'] + try: + string_match(pattern, input_list) + except ValueError as e: + assert str(e) == "Pattern must be a string or list of strings." + + # Test case 7: Invalid input list type + pattern = 'apple' + input_list = 123 + try: + string_match(pattern, input_list) + except ValueError as e: + assert str(e) == "Input list must be a list or numpy array." + + print("All test cases pass") + +test_string_match() \ No newline at end of file