@@ -94,11 +94,15 @@ def __init__(self, alf_path, ephys_path):
9494 def filter_units (self , type ):
9595 if type == 'all' :
9696 self .spike_idx = np .arange (self .spikes ['clusters' ].size )
97- self .kp_idx = np .where (~ np .isnan (self .spikes ['depths' ][self .spike_idx ]))[0 ]
97+ # Filter for nans in depths and also in amps
98+ self .kp_idx = np .where (~ np .isnan (self .spikes ['depths' ][self .spike_idx ]) &
99+ ~ np .isnan (self .spikes ['amps' ][self .spike_idx ]))[0 ]
100+
98101 else :
99102 clust = np .where (self .clusters .metrics .ks2_label == type )
100103 self .spike_idx = np .where (np .isin (self .spikes ['clusters' ], clust ))[0 ]
101- self .kp_idx = np .where (~ np .isnan (self .spikes ['depths' ][self .spike_idx ]))[0 ]
104+ self .kp_idx = np .where (~ np .isnan (self .spikes ['depths' ][self .spike_idx ]) & ~ np .isnan (
105+ self .spikes ['amps' ][self .spike_idx ]))[0 ]
102106
103107# Plots that require spike and cluster data
104108 def get_depth_data_scatter (self ):
@@ -107,29 +111,33 @@ def get_depth_data_scatter(self):
107111 return data_scatter
108112 else :
109113 A_BIN = 10
110- amp_range = np .quantile (self .spikes ['amps' ][self .spike_idx ], [0 , 0.9 ])
114+ amp_range = np .quantile (self .spikes ['amps' ][self .spike_idx ][ self . kp_idx ] , [0 , 0.9 ])
111115 amp_bins = np .linspace (amp_range [0 ], amp_range [1 ], A_BIN )
112116 colour_bin = np .linspace (0.0 , 1.0 , A_BIN )
113117 colours = (cm .get_cmap ('BuPu' )(colour_bin )[np .newaxis , :, :3 ][0 ]) * 255
114- spikes_colours = np .empty (self .spikes ['amps' ][self .spike_idx ].size , dtype = object )
115- spikes_size = np .empty (self .spikes ['amps' ][self .spike_idx ].size )
118+ spikes_colours = np .empty (self .spikes ['amps' ][self .spike_idx ][self .kp_idx ].size ,
119+ dtype = object )
120+ spikes_size = np .empty (self .spikes ['amps' ][self .spike_idx ][self .kp_idx ].size )
116121 for iA in range (amp_bins .size - 1 ):
117- idx = np .where ((self .spikes ['amps' ][self .spike_idx ] > amp_bins [iA ]) &
118- (self .spikes ['amps' ][self .spike_idx ] <= amp_bins [iA + 1 ]))[0 ]
122+ idx = np .where ((self .spikes ['amps' ][self .spike_idx ][self .kp_idx ] > amp_bins [iA ]) &
123+ (self .spikes ['amps' ][self .spike_idx ][self .kp_idx ] <=
124+ amp_bins [iA + 1 ]))[0 ]
119125
120126 spikes_colours [idx ] = QtGui .QColor (* colours [iA ])
121127 spikes_size [idx ] = iA / (A_BIN / 4 )
122128
123129 data_scatter = {
124- 'x' : self .spikes ['times' ][self .spike_idx ][0 :- 1 :100 ],
125- 'y' : self .spikes ['depths' ][self .spike_idx ][0 :- 1 :100 ],
130+ 'x' : self .spikes ['times' ][self .spike_idx ][self . kp_idx ][ 0 :- 1 :100 ],
131+ 'y' : self .spikes ['depths' ][self .spike_idx ][self . kp_idx ][ 0 :- 1 :100 ],
126132 'levels' : amp_range * 1e6 ,
127133 'colours' : spikes_colours [0 :- 1 :100 ],
128134 'pen' : None ,
129135 'size' : spikes_size [0 :- 1 :100 ],
130136 'symbol' : np .array ('o' ),
131- 'xrange' : np .array ([np .min (self .spikes ['times' ][self .spike_idx ][0 :- 1 :100 ]),
132- np .max (self .spikes ['times' ][self .spike_idx ][0 :- 1 :100 ])]),
137+ 'xrange' : np .array ([np .min (self .spikes ['times' ][self .spike_idx ][self .kp_idx ]
138+ [0 :- 1 :100 ]),
139+ np .max (self .spikes ['times' ][self .spike_idx ][self .kp_idx ]
140+ [0 :- 1 :100 ])]),
133141 'xaxis' : 'Time (s)' ,
134142 'title' : 'Amplitude (uV)' ,
135143 'cmap' : 'BuPu' ,
@@ -148,9 +156,11 @@ def get_fr_p2t_data_scatter(self):
148156 (clu ,
149157 spike_depths ,
150158 spike_amps ,
151- n_spikes ) = self .compute_spike_average (self .spikes ['clusters' ][self .spike_idx ],
152- self .spikes ['depths' ][self .spike_idx ],
153- self .spikes ['amps' ][self .spike_idx ])
159+ n_spikes ) = self .compute_spike_average ((self .spikes ['clusters' ][self .spike_idx ]
160+ [self .kp_idx ]), (self .spikes ['depths' ]
161+ [self .spike_idx ][self .kp_idx ]),
162+ (self .spikes ['amps' ][self .spike_idx ]
163+ [self .kp_idx ]))
154164 spike_amps = spike_amps * 1e6
155165 fr = n_spikes / np .max (self .spikes ['times' ])
156166 fr_norm , fr_levels = self .normalise_data (fr , lquant = 0 , uquant = 1 )
0 commit comments