Skip to content

Commit 1737893

Browse files
committed
fix squeeze bug (turns 1-value arrays into 0-dimensional arrays)
1 parent 449aac2 commit 1737893

File tree

4 files changed

+4
-3
lines changed

4 files changed

+4
-3
lines changed
-26 Bytes
Binary file not shown.
19 Bytes
Binary file not shown.

npyx/spk_t.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def ids(dp, unit, sav=True, prnt=False, subset_selection='all', again=False):
5353
except:pass
5454
if type(unit) is int:
5555
spike_clusters = np.load(Path(dp,"spike_clusters.npy"))
56-
indices = np.nonzero(spike_clusters==unit)[0].squeeze()
56+
indices = np.nonzero(spike_clusters==unit)[0]
5757
if type(unit) not in [str, np.str_, int]:
5858
print('WARNING unit {} type ({}) not handled!'.format(unit, type(unit)))
5959
return
@@ -114,11 +114,11 @@ def trn(dp, unit, sav=True, prnt=False, subset_selection='all', again=False, enf
114114
if ds_table.shape[0]>1: # If several datasets in prophyler
115115
spike_clusters_samples = np.load(Path(dp, 'merged_clusters_spikes.npy'))
116116
dataset_mask=(spike_clusters_samples[:, 0]==ds_i); unit_mask=(spike_clusters_samples[:, 1]==unt)
117-
train = spike_clusters_samples[dataset_mask&unit_mask, 2].squeeze().astype(np.int64)
117+
train = spike_clusters_samples[dataset_mask&unit_mask, 2].astype(np.int64)
118118
else:
119119
spike_clusters = np.load(Path(ds_table['dp'][0],"spike_clusters.npy"))
120120
spike_samples = np.load(Path(ds_table['dp'][0],'spike_times.npy'))
121-
train = spike_samples[spike_clusters==unt].squeeze()
121+
train = spike_samples[spike_clusters==unt]
122122
else:
123123
try:unit=int(unit)
124124
except:pass

npyx/spk_wvf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ def get_ids_subset(dp, unit, n_waveforms, batch_size_waveforms, subset_selection
403403
else:
404404
assert n_waveforms > 0
405405
spike_ids = ids(dp, unit)
406+
assert any(spike_ids)
406407
if subset_selection == 'regular':
407408
# Regular subselection.
408409
if batch_size_waveforms is None or len(spike_ids) <= max(batch_size_waveforms, n_waveforms):

0 commit comments

Comments
 (0)