Skip to content

Commit dfe6749

Browse files
committed
fix: fix a few bugs
1 parent 9806dbf commit dfe6749

File tree

5 files changed

+67
-49
lines changed

5 files changed

+67
-49
lines changed

ephyspy/analysis.py

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,11 @@ def plot_sweepset_ft(fts, ft, ax, **kwargs):
137137
def sweep_idx(fts, ft):
138138
try:
139139
FT = fts.lookup_sweepset_feature(ft, return_value=False)
140-
return FT.diagnostics["selected_idx"]
140+
idx = FT.diagnostics["selected_idx"]
141+
if isinstance(idx, (np.ndarray, list)):
142+
if len(idx) == 0:
143+
return slice(0)
144+
return idx
141145
except KeyError:
142146
return slice(0)
143147
except TypeError:
@@ -146,8 +150,11 @@ def sweep_idx(fts, ft):
146150

147151
def spike_idx(fts, ft):
148152
sw_idx = sweep_idx(fts, ft)
149-
FT = fts.lookup_sweep_feature(ft, return_value=False)
150-
return FT[sw_idx].diagnostics["aggregate_idx"]
153+
if not sw_idx == slice(0):
154+
FT = fts.lookup_sweep_feature(ft, return_value=False)
155+
return FT[sw_idx].diagnostics["aggregate_idx"]
156+
else:
157+
return slice(0)
151158

152159
fig, axes = plt.subplot_mosaic(mosaic, figsize=figsize, constrained_layout=True)
153160
onset = fts.lookup_sweep_feature("stim_onset")[0]
@@ -159,8 +166,9 @@ def spike_idx(fts, ft):
159166
# set
160167
selected_sweeps = {}
161168
for ft in available_sweepset_features():
162-
sweep = sweepset[sweep_idx(fts, ft)]
163-
selected_sweeps[ft] = sweep if not sweep == [] else None
169+
if not isinstance(idx := sweep_idx(fts, ft), (list, np.ndarray)):
170+
sweep = sweepset[idx]
171+
selected_sweeps[ft] = sweep if not sweep == [] else None
164172

165173
unique_sweeps = {}
166174
for k, v in selected_sweeps.items():
@@ -202,9 +210,11 @@ def spike_idx(fts, ft):
202210
plot_sweepset_ft(fts, "ap_freq_adapt", axes["fp_trace"])
203211
plot_sweepset_ft(fts, "ap_amp_slope", axes["fp_trace"])
204212

205-
stim = sweepset[sweep_idx(fts, "num_ap")].i
206-
stim_amp = int(np.max(stim) + np.min(stim))
207-
axes["fp_trace"].legend(title=f"@{stim_amp }pA")
213+
num_ap_sweep_idx = sweep_idx(fts, "num_ap")
214+
if not num_ap_sweep_idx == slice(0):
215+
stim = sweepset[num_ap_sweep_idx].i
216+
stim_amp = int(np.max(stim) + np.min(stim))
217+
axes["fp_trace"].legend(title=f"@{stim_amp }pA")
208218

209219
# different selection / aggregation
210220
# plot_sweepset_ft(fts, "ap_amp_adapt", axes["fp_trace"])
@@ -224,32 +234,36 @@ def spike_idx(fts, ft):
224234
plot_sweepset_ft(fts, "ap_adp", axes["ap_trace"])
225235
plot_sweepset_ft(fts, "ap_udr", axes["ap_trace"])
226236

227-
stim = sweepset[sweep_idx(fts, "ap_thresh")].i
228-
stim_amp = int(np.max(stim) + np.min(stim))
229-
axes["ap_trace"].legend(title=f"@{stim_amp }pA")
230-
231-
ap_sweep = sweepset[ap_sweep_idx]
232-
for i, ft in enumerate(available_spike_features()):
233-
plot_spike_feature(ap_sweep, ft, axes["ap_window"], color=f"C{i}")
234-
235-
ap_start = ap_sweep.spike_feature("threshold_t")[ap_idx] - 5e-3
236-
ap_end = ap_sweep.spike_feature("fast_trough_t")[ap_idx] + 5e-3
237-
if isinstance(ap_start, np.ndarray):
238-
ap_start = ap_start[0]
239-
ap_end = ap_end[-1]
240-
axes["ap_window"].set_xlim(ap_start, ap_end)
241-
axes["ap_trace"].axvline(ap_start, color="grey")
242-
axes["ap_trace"].axvline(ap_end, color="grey", label="selected ap")
243-
ap_sweep.plot(axes["ap_window"])
244-
axes["ap_window"].legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
237+
ap_thresh_sweep_idx = sweep_idx(fts, "ap_thresh")
238+
if not ap_thresh_sweep_idx == slice(0):
239+
stim = sweepset[ap_thresh_sweep_idx].i
240+
stim_amp = int(np.max(stim) + np.min(stim))
241+
axes["ap_trace"].legend(title=f"@{stim_amp }pA")
242+
243+
if not ap_sweep_idx == slice(0):
244+
ap_sweep = sweepset[ap_sweep_idx]
245+
for i, ft in enumerate(available_spike_features()):
246+
plot_spike_feature(ap_sweep, ft, axes["ap_window"], color=f"C{i}")
247+
248+
ap_start = ap_sweep.spike_feature("threshold_t")[ap_idx] - 5e-3
249+
ap_end = ap_sweep.spike_feature("fast_trough_t")[ap_idx] + 5e-3
250+
if isinstance(ap_start, np.ndarray):
251+
ap_start = ap_start[0]
252+
ap_end = ap_end[-1]
253+
axes["ap_window"].set_xlim(ap_start, ap_end)
254+
axes["ap_trace"].axvline(ap_start, color="grey")
255+
axes["ap_trace"].axvline(ap_end, color="grey", label="selected ap")
256+
ap_sweep.plot(axes["ap_window"])
257+
axes["ap_window"].legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
245258

246259
# hyperpol
247-
plot_sweepset_ft(fts, "tau", axes["set_hyperpol_fts"])
248260
plot_sweepset_ft(fts, "v_baseline", axes["set_hyperpol_fts"])
249261

250-
stim = sweepset[sweep_idx(fts, "tau")].i
251-
stim_amp = int(np.max(stim) + np.min(stim))
252-
axes["set_hyperpol_fts"].legend(title=f"@{stim_amp }pA")
262+
tau_sweep_idx = sweep_idx(fts, "tau")
263+
if not tau_sweep_idx == slice(0):
264+
stim = sweepset[tau_sweep_idx].i
265+
stim_amp = int(np.max(stim) + np.min(stim))
266+
axes["set_hyperpol_fts"].legend(title=f"@{stim_amp }pA")
253267

254268
# sag
255269
plot_sweepset_ft(fts, "sag_area", axes["sag_fts"])
@@ -261,9 +275,11 @@ def spike_idx(fts, ft):
261275
plot_sweepset_ft(fts, "sag", axes["sag_fts"])
262276
axes["sag_fts"].set_xlim(onset - 0.05, end + 0.05)
263277

264-
stim = sweepset[sweep_idx(fts, "sag")].i
265-
stim_amp = int(np.max(stim) + np.min(stim))
266-
axes["sag_fts"].legend(title=f"@{stim_amp }pA")
278+
sag_sweep_idx = sweep_idx(fts, "sag")
279+
if not sag_sweep_idx == slice(0):
280+
stim = sweepset[sag_sweep_idx].i
281+
stim_amp = int(np.max(stim) + np.min(stim))
282+
axes["sag_fts"].legend(title=f"@{stim_amp }pA")
267283

268284
# rebound
269285
plot_sweepset_ft(fts, "rebound", axes["rebound_fts"])
@@ -272,9 +288,11 @@ def spike_idx(fts, ft):
272288
plot_sweepset_ft(fts, "rebound_avg", axes["rebound_fts"])
273289
axes["rebound_fts"].set_xlim(end - 0.05, None)
274290

275-
stim = sweepset[sweep_idx(fts, "rebound")].i
276-
stim_amp = int(np.max(stim) + np.min(stim))
277-
axes["rebound_fts"].legend(title=f"@{stim_amp }pA")
291+
rebound_sweep_idx = sweep_idx(fts, "rebound")
292+
if not rebound_sweep_idx == slice(0):
293+
stim = sweepset[rebound_sweep_idx].i
294+
stim_amp = int(np.max(stim) + np.min(stim))
295+
axes["rebound_fts"].legend(title=f"@{stim_amp }pA")
278296
axes["rebound_fts"].legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
279297

280298
fig.text(-0.02, 0.5, "U (mV)", va="center", rotation="vertical", fontsize=16)

ephyspy/features/spike_features.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes
9494
ax = scatter_spike_ft(
9595
"upstroke", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
9696
)
97-
kwargs["color"] = next(ax._get_lines.prop_cycler)["color"]
97+
kwargs["color"] = ax._get_lines._cycler_items[0]["color"]
9898
ax.plot(t, up_dvdt[idxs] * (t - up_t[idxs]) + up_v[idxs], **kwargs)
9999
return ax
100100

@@ -137,7 +137,7 @@ def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes
137137
ax = scatter_spike_ft(
138138
"downstroke", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
139139
)
140-
kwargs["color"] = next(ax._get_lines.prop_cycler)["color"]
140+
kwargs["color"] = ax._get_lines._cycler_items[0]["color"]
141141
ax.plot(t, down_dvdt[idxs] * (t - down_t[idxs]) + down_v[idxs], **kwargs)
142142
return ax
143143

ephyspy/features/sweepset_features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def _aggregate(self, fts):
285285
"selected_idx": hyperpol_idx[median_idx(fts)],
286286
}
287287
)
288-
med = float("nan") if len(fts) == 0 else np.nanmedian(fts).item()
288+
med = float("nan") if len(fts) == 0 or np.all(np.isnan(fts)) else np.nanmedian(fts).item()
289289
return med
290290

291291

ephyspy/features/utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,10 @@ def sag_period(sweep: EphysSweep, where_sag: ndarray) -> float:
264264

265265

266266
def where_stimulus(data: Union[EphysSweep, EphysSweepSet]) -> Union[bool, ndarray]:
267-
"""Checks where the stimulus is non-zero.
267+
"""Checks where the stimulus is unequal to current at t=0.
268268
269-
Checks where stimulus is non-zero for a single sweep or each sweep in a
270-
sweepset.
269+
Checks where stimulus is unequal to current at t=0 for a single sweep or each
270+
sweep in a sweepset.
271271
272272
Args:
273273
data (EphysSweep or EphysSweepSet):
@@ -276,7 +276,7 @@ def where_stimulus(data: Union[EphysSweep, EphysSweepSet]) -> Union[bool, ndarra
276276
Returns:
277277
bool: True if stimulus is non-zero.
278278
"""
279-
return data.i.T != 0
279+
return data.i.T != data.i.T[0]
280280

281281

282282
def where_spike_during_stimulus(
@@ -326,7 +326,7 @@ def has_stimulus(data: Union[EphysSweep, EphysSweepSet]) -> Union[bool, ndarray]
326326
327327
Returns:
328328
bool: True if sweep has stimulus."""
329-
return np.any(where_stimulus(data), axis=0)
329+
return np.any(data.i.T*where_stimulus(data) != 0, axis=0)
330330

331331

332332
def is_hyperpol(data: Union[EphysSweep, EphysSweepSet]) -> Union[bool, ndarray]:
@@ -338,7 +338,7 @@ def is_hyperpol(data: Union[EphysSweep, EphysSweepSet]) -> Union[bool, ndarray]:
338338
339339
Returns:
340340
bool: True if sweep is hyperpolarizing."""
341-
return np.any(data.i.T < 0, axis=0)
341+
return np.any(data.i.T*where_stimulus(data) < 0, axis=0)
342342

343343

344344
def is_depol(data: Union[EphysSweep, EphysSweepSet]) -> Union[bool, ndarray]:
@@ -350,7 +350,7 @@ def is_depol(data: Union[EphysSweep, EphysSweepSet]) -> Union[bool, ndarray]:
350350
351351
Returns:
352352
bool: True if sweep is depolarizing."""
353-
return np.any(data.i.T > 0, axis=0)
353+
return np.any(data.i.T*where_stimulus(data) > 0, axis=0)
354354

355355

356356
def has_rebound(feature: Any, T_rebound: float = 0.3) -> bool:
@@ -388,7 +388,7 @@ def median_idx(d: Union[DataFrame, ndarray]) -> Union[int, slice]:
388388
Union[int, slice]: Index of median value or slice(0) if d is empty or
389389
all nan."""
390390
d = d if isinstance(d, DataFrame) else DataFrame(d)
391-
if len(d) > 0:
391+
if len(d) > 0 and not np.all(d.isna()):
392392
is_median = d == d.median()
393393
if any(is_median):
394394
return int(d.index[is_median].to_numpy())

ephyspy/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ def stimulus_type(sweep_or_sweepset: Union[EphysSweep, EphysSweepSet]) -> str:
169169
if np.all(slope == 0):
170170
return "long_square"
171171

172-
rel_slope_change = np.abs(slope - slope[0]) / slope[0]
173-
if np.all(rel_slope_change < 0.001) and np.all(slope > 0): # same slope
172+
rel_slope_change = np.abs(slope - slope[0]) / slope[0] if np.all(slope > 0) else np.nan
173+
if np.all(rel_slope_change < 0.001) and not np.isnan(rel_slope_change): # same slope
174174
return "ramp"
175175
else:
176176
return "unknown"

0 commit comments

Comments
 (0)