1+ from .. import probe
2+ from modulefinder import Module
13import numpy as np
24import pandas as pd
35import plotly .graph_objs as go
@@ -65,19 +67,17 @@ def plot_correlogram(
6567 template = "simple_white" ,
6668 width = 350 ,
6769 height = 350 ,
68- yaxis_range = [0 , None ]
70+ yaxis_range = [0 , None ],
6971 )
7072 return fig
7173
7274
7375def plot_depth_waveforms (
76+ ephys : Module ,
7477 unit_key : dict ,
7578 y_range : float = 60 ,
7679) -> go .Figure :
7780
78- from .. import probe
79- from .. import ephys_no_curation as ephys
80-
8181 sampling_rate = (ephys .EphysRecording & unit_key ).fetch1 (
8282 "sampling_rate"
8383 ) / 1e3 # in kHz
@@ -122,13 +122,13 @@ def plot_depth_waveforms(
122122 x_min , x_max = np .min (coords [:, 0 ]), np .max (coords [:, 0 ])
123123 y_min , y_max = np .min (coords [:, 1 ]), np .max (coords [:, 1 ])
124124
125- # Spacing between channels (in um)
126- x_inc = np .abs (np .diff (coords [:, 0 ] )).min ()
127- y_inc = ( np .abs (np .diff (coords [:, 1 ]))). max ()
125+ # Spacing between recording sites (in um)
126+ x_inc = np .abs (np .diff (coords [coords [ :, 1 ] == coords [ 0 , 1 ]][:, 0 ] )).mean () / 2
127+ y_inc = np .abs (np .diff (coords [coords [ :, 0 ] == coords [ 0 , 0 ]][:, 1 ])). mean () / 2
128128
129129 time = np .arange (waveforms .shape [1 ]) / sampling_rate
130130
131- x_scale_factor = x_inc / (time [- 1 ] + 1 / sampling_rate )
131+ x_scale_factor = x_inc / (time [- 1 ] + 1 / sampling_rate ) # correspond to 1 ms
132132 time_scaled = time * x_scale_factor
133133
134134 wf_amps = waveforms .max (axis = 1 ) - waveforms .min (axis = 1 )
@@ -152,7 +152,7 @@ def plot_depth_waveforms(
152152 x = time_scaled + coord [0 ],
153153 y = wf_scaled + coord [1 ],
154154 mode = "lines" ,
155- line = dict (color = color , width = 1 ),
155+ line = dict (color = color , width = 1.5 ),
156156 hovertemplate = f"electrode { electrode } <br>"
157157 + f"x ={ coord [0 ]: .0f} μm<br>"
158158 + f"y ={ coord [1 ]: .0f} μm<extra></extra>" ,
@@ -164,7 +164,7 @@ def plot_depth_waveforms(
164164 yaxis_title = "Distance from the probe tip (μm)" ,
165165 template = "simple_white" ,
166166 width = 400 ,
167- height = 700 ,
167+ height = 600 ,
168168 xaxis_range = [x_min - x_inc / 2 , x_max + x_inc * 1.2 ],
169169 yaxis_range = [y_min - y_inc * 2 , y_max + y_inc * 2 ],
170170 )
@@ -173,12 +173,12 @@ def plot_depth_waveforms(
173173 fig .update_xaxes (tickvals = xtick_loc , ticktext = xtick_label )
174174
175175 # Add a scale bar
176- x0 = xtick_loc [0 ] / 6
177- y0 = y_min - y_inc * 1.5
176+ x0 = xtick_loc [0 ] - ( x_scale_factor * 1.5 )
177+ y0 = y_min - ( y_inc * 1.5 )
178178
179179 fig .add_trace (
180180 go .Scatter (
181- x = [x0 , xtick_loc [ 0 ] + x_scale_factor ],
181+ x = [x0 , x0 + x_scale_factor ],
182182 y = [y0 , y0 ],
183183 mode = "lines" ,
184184 line = dict (color = "black" , width = 2 ),
0 commit comments