Skip to content

Commit 1927d12

Browse files
pyqt window function plot
1 parent 59acfb9 commit 1927d12

File tree

1 file changed

+111
-8
lines changed

1 file changed

+111
-8
lines changed

ephys/classes/plot/plot_window_functions.py

Lines changed: 111 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,26 @@
1313
import numpy as np
1414
from matplotlib.axes import Axes
1515
from matplotlib.figure import Figure
16+
import pyqtgraph as pg
17+
import pandas as pd
1618

1719
from ephys import utils
1820
from ephys.classes.plot.plot_params import PlotParams
21+
1922
from ephys.classes.class_functions import moving_average
2023

2124
if TYPE_CHECKING:
2225
from ephys.classes.window_functions import FunctionOutput
26+
from ephys.classes.plot.plot_trace import TracePlotPyQt
2327
from ephys.classes.trace import Trace
2428

2529

2630
class FunctionOutputPlot:
2731
"""Class for plotting traces and summary measurements with matplotlib."""
2832

29-
def __init__(self, function_output: FunctionOutput, **kwargs: Any) -> None:
33+
def __init__(
34+
self, function_output: FunctionOutput, trace: Trace | None = None, **kwargs: Any
35+
) -> None:
3036
"""
3137
Initializes the FunctionOutputPlot class with function_output and
3238
additional arguments.
@@ -58,9 +64,10 @@ def __init__(self, function_output: FunctionOutput, **kwargs: Any) -> None:
5864

5965
def plot(
6066
self,
67+
trace: Trace | None = None,
6168
label_filter: list | str | None = None,
6269
**kwargs: Any,
63-
) -> None:
70+
) -> FunctionOutputPyQt | None:
6471
"""
6572
Plots the trace and/or summary measurements.
6673
@@ -71,16 +78,112 @@ def plot(
7178
Returns:
7279
None
7380
"""
74-
# TODO: build PyQtGraph plot
75-
if kwargs:
76-
self.params.update_params(**kwargs)
81+
from ephys.classes.trace import Trace
82+
83+
self.params.update_params(**kwargs)
84+
85+
align_onset: bool = self.params.__dict__.get("align_onset", True)
86+
show: bool = self.params.__dict__.get("show", True)
87+
trace_channels: list[pg.PlotItem] | None = None
88+
# trace = self.function_output.trace
7789
if self.function_output.measurements.size == 0:
7890
print("No measurements to plot")
7991
return None
92+
window_groups = self.function_output.to_dataframe().groupby(["unit", "channel"])
93+
channel_plot_dict: dict[str, pg.PlotItem] = {}
94+
self.win = pg.GraphicsLayoutWidget(show=show, title="Summary Plot")
8095

8196
if label_filter is None:
8297
label_filter = []
83-
return None
98+
if isinstance(trace, Trace):
99+
from ephys.classes.plot.plot_trace import TracePlotPyQt
100+
101+
trace_select: Trace = trace.subset(
102+
channels=self.function_output.channel,
103+
signal_type=self.function_output.signal_type,
104+
)
105+
trace_plot_params = deepcopy(self.params.__dict__)
106+
trace_plot_params["show"] = False
107+
trace_plot_params["return_fig"] = True
108+
trace_plot = trace_select.plot(backend="pyqt", **trace_plot_params)
109+
if isinstance(trace_plot, TracePlotPyQt):
110+
x_range = trace_plot.params.xlim
111+
trace_channels = [
112+
item
113+
for item in trace_plot.win.items()
114+
if isinstance(item, pg.PlotItem)
115+
]
116+
self.win: pg.GraphicsLayoutWidget = trace_plot.win
117+
channel_plot_dict = {
118+
channel.getAxis("left").labelText: channel
119+
for channel in trace_channels
120+
}
121+
self.win.setBackground(self.params.bg_color)
122+
123+
channel_0: pg.PlotItem | None = None
124+
from ephys.classes.class_functions import moving_average
125+
126+
if self.params.align_onset:
127+
x_axis = self.function_output.location
128+
else:
129+
if trace is not None:
130+
x_axis = self.function_output.location + np.array(
131+
[
132+
trace.time[int(sweep - 1), 0]
133+
for sweep in self.function_output.sweep
134+
]
135+
)
136+
else:
137+
x_axis = self.function_output.time
138+
x_range = (np.min(x_axis), np.max(x_axis))
139+
for channel_index, ((unit, channel), group) in enumerate(window_groups):
140+
if isinstance(trace_channels, list):
141+
channel_plot = channel_plot_dict[f"Channel {int(channel)} " f"({unit})"]
142+
else:
143+
channel_plot: pg.PlotItem = self.win.addPlot(row=channel_index, col=0) # type: ignore
144+
if channel_index == 0:
145+
channel_0 = channel_plot
146+
channel_plot.addLegend()
147+
channel_plot.setXLink(channel_0) # type: ignore
148+
channel_box = channel_plot.getViewBox()
149+
channel_box.setXRange(x_range[0], x_range[1]) # type: ignore
150+
subgroups = group.groupby("label")
151+
label_count = len(subgroups)
152+
for i, (label, subgroup) in enumerate(subgroups):
153+
label_colors = utils.color_picker_qcolor(
154+
length=label_count, index=i, color="gist_rainbow", alpha=0.5
155+
)
156+
if not align_onset:
157+
y_smooth = moving_average(
158+
subgroup["measurements"].to_numpy(),
159+
subgroup.index.size // 10,
160+
)
161+
channel_plot.plot(
162+
x_axis[subgroup.index],
163+
y_smooth,
164+
pen=label_colors,
165+
alpha=0.4,
166+
width=2.0,
167+
)
168+
channel_plot.setLabel("left", f"Channel {int(channel)} " f"({unit})")
169+
channel_plot.scatterPlot(
170+
x_axis[subgroup.index],
171+
subgroup["measurements"].values,
172+
pen=label_colors,
173+
symbol="o",
174+
symbolSize=10,
175+
symbolPen="w",
176+
symbolBrush=label_colors,
177+
name=f"{label}",
178+
)
179+
if show:
180+
self.show()
181+
return self
182+
183+
def show(self) -> pg.GraphicsLayoutWidget | None:
184+
if self.win is not None:
185+
self.win.show()
186+
return self.win
84187

85188

86189
class FunctionOutputMatplotlib(FunctionOutputPlot):
@@ -157,7 +260,7 @@ def plot(
157260
channel_count = np.unique(self.function_output.channel).size
158261
unique_labels = np.unique(self.function_output.label)
159262
if align_onset:
160-
x_axis = self.function_output.location
263+
x_axis = self.function_output.location.copy()
161264
else:
162265
if trace is not None:
163266
x_axis = self.function_output.location + np.array(
@@ -167,7 +270,7 @@ def plot(
167270
]
168271
)
169272
else:
170-
x_axis = self.function_output.time
273+
x_axis = self.function_output.time.copy()
171274
for color_index, label in enumerate(unique_labels):
172275
# add section to plot on channel by channel basis
173276
for channel_index, channel_number in enumerate(

0 commit comments

Comments
 (0)