1313import numpy as np
1414from matplotlib .axes import Axes
1515from matplotlib .figure import Figure
16+ import pyqtgraph as pg
17+ import pandas as pd
1618
1719from ephys import utils
1820from ephys .classes .plot .plot_params import PlotParams
21+
1922from ephys .classes .class_functions import moving_average
2023
2124if TYPE_CHECKING :
2225 from ephys .classes .window_functions import FunctionOutput
26+ from ephys .classes .plot .plot_trace import TracePlotPyQt
2327 from ephys .classes .trace import Trace
2428
2529
2630class FunctionOutputPlot :
2731 """Class for plotting traces and summary measurements with matplotlib."""
2832
29- def __init__ (self , function_output : FunctionOutput , ** kwargs : Any ) -> None :
33+ def __init__ (
34+ self , function_output : FunctionOutput , trace : Trace | None = None , ** kwargs : Any
35+ ) -> None :
3036 """
3137 Initializes the FunctionOutputPlot class with function_output and
3238 additional arguments.
@@ -58,9 +64,10 @@ def __init__(self, function_output: FunctionOutput, **kwargs: Any) -> None:
5864
5965 def plot (
6066 self ,
67+ trace : Trace | None = None ,
6168 label_filter : list | str | None = None ,
6269 ** kwargs : Any ,
63- ) -> None :
70+ ) -> FunctionOutputPyQt | None :
6471 """
6572 Plots the trace and/or summary measurements.
6673
@@ -71,16 +78,112 @@ def plot(
7178 Returns:
7279 None
7380 """
74- # TODO: build PyQtGraph plot
75- if kwargs :
76- self .params .update_params (** kwargs )
81+ from ephys .classes .trace import Trace
82+
83+ self .params .update_params (** kwargs )
84+
85+ align_onset : bool = self .params .__dict__ .get ("align_onset" , True )
86+ show : bool = self .params .__dict__ .get ("show" , True )
87+ trace_channels : list [pg .PlotItem ] | None = None
88+ # trace = self.function_output.trace
7789 if self .function_output .measurements .size == 0 :
7890 print ("No measurements to plot" )
7991 return None
92+ window_groups = self .function_output .to_dataframe ().groupby (["unit" , "channel" ])
93+ channel_plot_dict : dict [str , pg .PlotItem ] = {}
94+ self .win = pg .GraphicsLayoutWidget (show = show , title = "Summary Plot" )
8095
8196 if label_filter is None :
8297 label_filter = []
83- return None
98+ if isinstance (trace , Trace ):
99+ from ephys .classes .plot .plot_trace import TracePlotPyQt
100+
101+ trace_select : Trace = trace .subset (
102+ channels = self .function_output .channel ,
103+ signal_type = self .function_output .signal_type ,
104+ )
105+ trace_plot_params = deepcopy (self .params .__dict__ )
106+ trace_plot_params ["show" ] = False
107+ trace_plot_params ["return_fig" ] = True
108+ trace_plot = trace_select .plot (backend = "pyqt" , ** trace_plot_params )
109+ if isinstance (trace_plot , TracePlotPyQt ):
110+ x_range = trace_plot .params .xlim
111+ trace_channels = [
112+ item
113+ for item in trace_plot .win .items ()
114+ if isinstance (item , pg .PlotItem )
115+ ]
116+ self .win : pg .GraphicsLayoutWidget = trace_plot .win
117+ channel_plot_dict = {
118+ channel .getAxis ("left" ).labelText : channel
119+ for channel in trace_channels
120+ }
121+ self .win .setBackground (self .params .bg_color )
122+
123+ channel_0 : pg .PlotItem | None = None
124+ from ephys .classes .class_functions import moving_average
125+
126+ if self .params .align_onset :
127+ x_axis = self .function_output .location
128+ else :
129+ if trace is not None :
130+ x_axis = self .function_output .location + np .array (
131+ [
132+ trace .time [int (sweep - 1 ), 0 ]
133+ for sweep in self .function_output .sweep
134+ ]
135+ )
136+ else :
137+ x_axis = self .function_output .time
138+ x_range = (np .min (x_axis ), np .max (x_axis ))
139+ for channel_index , ((unit , channel ), group ) in enumerate (window_groups ):
140+ if isinstance (trace_channels , list ):
141+ channel_plot = channel_plot_dict [f"Channel { int (channel )} " f"({ unit } )" ]
142+ else :
143+ channel_plot : pg .PlotItem = self .win .addPlot (row = channel_index , col = 0 ) # type: ignore
144+ if channel_index == 0 :
145+ channel_0 = channel_plot
146+ channel_plot .addLegend ()
147+ channel_plot .setXLink (channel_0 ) # type: ignore
148+ channel_box = channel_plot .getViewBox ()
149+ channel_box .setXRange (x_range [0 ], x_range [1 ]) # type: ignore
150+ subgroups = group .groupby ("label" )
151+ label_count = len (subgroups )
152+ for i , (label , subgroup ) in enumerate (subgroups ):
153+ label_colors = utils .color_picker_qcolor (
154+ length = label_count , index = i , color = "gist_rainbow" , alpha = 0.5
155+ )
156+ if not align_onset :
157+ y_smooth = moving_average (
158+ subgroup ["measurements" ].to_numpy (),
159+ subgroup .index .size // 10 ,
160+ )
161+ channel_plot .plot (
162+ x_axis [subgroup .index ],
163+ y_smooth ,
164+ pen = label_colors ,
165+ alpha = 0.4 ,
166+ width = 2.0 ,
167+ )
168+ channel_plot .setLabel ("left" , f"Channel { int (channel )} " f"({ unit } )" )
169+ channel_plot .scatterPlot (
170+ x_axis [subgroup .index ],
171+ subgroup ["measurements" ].values ,
172+ pen = label_colors ,
173+ symbol = "o" ,
174+ symbolSize = 10 ,
175+ symbolPen = "w" ,
176+ symbolBrush = label_colors ,
177+ name = f"{ label } " ,
178+ )
179+ if show :
180+ self .show ()
181+ return self
182+
183+ def show (self ) -> pg .GraphicsLayoutWidget | None :
184+ if self .win is not None :
185+ self .win .show ()
186+ return self .win
84187
85188
86189class FunctionOutputMatplotlib (FunctionOutputPlot ):
@@ -157,7 +260,7 @@ def plot(
157260 channel_count = np .unique (self .function_output .channel ).size
158261 unique_labels = np .unique (self .function_output .label )
159262 if align_onset :
160- x_axis = self .function_output .location
263+ x_axis = self .function_output .location . copy ()
161264 else :
162265 if trace is not None :
163266 x_axis = self .function_output .location + np .array (
@@ -167,7 +270,7 @@ def plot(
167270 ]
168271 )
169272 else :
170- x_axis = self .function_output .time
273+ x_axis = self .function_output .time . copy ()
171274 for color_index , label in enumerate (unique_labels ):
172275 # add section to plot on channel by channel basis
173276 for channel_index , channel_number in enumerate (
0 commit comments