Skip to content

Commit 8428d44

Browse files
authored
Merge pull request #1196 from JuliaSprenger/enh/ci_performance
Refactor test data updating and io tests
2 parents 7af346a + 21f393c commit 8428d44

File tree

7 files changed

+248
-305
lines changed

7 files changed

+248
-305
lines changed

neo/io/klustakwikio.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515

1616
import glob
1717
import logging
18-
import os.path
18+
from pathlib import Path
1919
import shutil
2020

21-
# note neo.core need only numpy and quantitie
21+
# note neo.core need only numpy and quantities
2222
import numpy as np
2323

2424

@@ -87,27 +87,31 @@ class KlustaKwikIO(BaseIO):
8787
extensions = ['fet', 'clu', 'res', 'spk']
8888

8989
# Operates on directories
90-
mode = 'file'
90+
mode = 'dir'
9191

92-
def __init__(self, filename, sampling_rate=30000.):
92+
def __init__(self, dirname, sampling_rate=30000.):
9393
"""Create a new IO to operate on a directory
9494
95-
filename : the directory to contain the files
96-
basename : string, basename of KlustaKwik format, or None
95+
dirname : the directory to contain the files
9796
sampling_rate : in Hz, necessary because the KlustaKwik files
9897
stores data in samples.
9998
"""
10099
BaseIO.__init__(self)
101-
# self.filename = os.path.normpath(filename)
102-
self.filename, self.basename = os.path.split(os.path.abspath(filename))
100+
self.dirname = Path(dirname)
101+
# in case no basename is provided
102+
if self.dirname.is_dir():
103+
self.session_dir = self.dirname
104+
else:
105+
self.session_dir = self.dirname.parent
106+
self.basename = self.dirname.name
103107
self.sampling_rate = float(sampling_rate)
104108

105109
# error check
106-
if not os.path.isdir(self.filename):
107-
raise ValueError("filename must be a directory")
110+
if not self.session_dir.is_dir():
111+
raise ValueError("dirname must be in an existing directory")
108112

109113
# initialize a helper object to parse filenames
110-
self._fp = FilenameParser(dirname=self.filename, basename=self.basename)
114+
self._fp = FilenameParser(dirname=self.session_dir, basename=self.basename)
111115

112116
def read_block(self, lazy=False):
113117
"""Returns a Block containing spike information.
@@ -130,7 +134,7 @@ def read_block(self, lazy=False):
130134
return block
131135

132136
# Create a single segment to hold all of the data
133-
seg = Segment(name='seg0', index=0, file_origin=self.filename)
137+
seg = Segment(name='seg0', index=0, file_origin=str(self.session_dir / self.basename))
134138
block.segments.append(seg)
135139

136140
# Load spike times from each group and store in a dict, keyed
@@ -367,15 +371,13 @@ def _make_all_file_handles(self, block):
367371

368372
def _new_group(self, id_group, nbClusters):
369373
# generate filenames
370-
fetfilename = os.path.join(self.filename,
371-
self.basename + ('.fet.%d' % id_group))
372-
clufilename = os.path.join(self.filename,
373-
self.basename + ('.clu.%d' % id_group))
374+
fetfilename = self.session_dir / (self.basename + ('.fet.%d' % id_group))
375+
clufilename = self.session_dir / (self.basename + ('.clu.%d' % id_group))
374376

375377
# back up before overwriting
376-
if os.path.exists(fetfilename):
378+
if fetfilename.exists():
377379
shutil.copyfile(fetfilename, fetfilename + '~')
378-
if os.path.exists(clufilename):
380+
if clufilename.exists():
379381
shutil.copyfile(clufilename, clufilename + '~')
380382

381383
# create file handles
@@ -406,12 +408,12 @@ def __init__(self, dirname, basename=None):
406408
will be used. An error is raised if files with multiple basenames
407409
exist in the directory.
408410
"""
409-
self.dirname = os.path.normpath(dirname)
411+
self.dirname = Path(dirname).absolute()
410412
self.basename = basename
411413

412414
# error check
413-
if not os.path.isdir(self.dirname):
414-
raise ValueError("filename must be a directory")
415+
if not self.dirname.is_dir():
416+
raise ValueError("dirname must be a directory")
415417

416418
def read_filenames(self, typestring='fet'):
417419
"""Returns filenames in the data directory matching the type.
@@ -430,13 +432,13 @@ def read_filenames(self, typestring='fet'):
430432
a sequence of digits are valid. The digits are converted to an integer
431433
and used as the group number.
432434
"""
433-
all_filenames = glob.glob(os.path.join(self.dirname, '*'))
435+
all_filenames = self.dirname.glob('*')
434436

435437
# Fill the dict with valid filenames
436438
d = {}
437439
for v in all_filenames:
438440
# Test whether matches format, ie ends with digits
439-
split_fn = os.path.split(v)[1]
441+
split_fn = v.name
440442
m = glob.re.search((r'^(\w+)\.%s\.(\d+)$' % typestring), split_fn)
441443
if m is not None:
442444
# get basename from first hit if not specified

neo/io/neuroshareapiio.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,26 +77,29 @@ class NeuroshareapiIO(BaseIO):
7777
# This object operates on neuroshare files
7878
mode = "file"
7979

80-
def __init__(self, filename=None, dllpath=None):
80+
def __init__(self, filename=None, dllname=None):
8181
"""
8282
Arguments:
8383
filename : the filename
84+
dllname: the path of the library to use for reading
8485
The init function will run automatically upon calling of the class, as
8586
in: test = MultichannelIO(filename = filetoberead.mcd), therefore the first
8687
operations with the file are set here, so that the user doesn't have to
8788
remember to use another method, than the ones defined in the NEO library
8889
8990
"""
9091
BaseIO.__init__(self)
91-
self.filename = filename
92+
self.filename = str(filename)
9293
# set the flags for each event type
9394
eventID = 1
9495
analogID = 2
9596
epochID = 3
9697
# if a filename was given, create a dictionary with information that will
9798
# be needed later on.
9899
if self.filename is not None:
99-
if dllpath is not None:
100+
if dllname is not None:
101+
# converting to string to also accept pathlib objects
102+
dllpath = str(dllname)
100103
name = os.path.splitext(os.path.basename(dllpath))[0]
101104
library = ns.Library(name, dllpath)
102105
else:
@@ -330,13 +333,13 @@ def read_spiketrain(self,
330333
numIndx = endat - startat
331334
# get the end point using segment duration
332335
# create a numpy empty array to store the waveforms
333-
waveforms = np.array(np.zeros([numIndx, tempSpks.max_sample_count]))
336+
waveforms = np.array(np.zeros([numIndx, 1, tempSpks.max_sample_count]))
334337
# loop through the data from the specific channel index
335338
for i in range(startat, endat, 1):
336339
# get cutout, timestamp, cutout duration, and spike unit
337340
tempCuts, timeStamp, duration, unit = tempSpks.get_data(i)
338341
# save the cutout in the waveform matrix
339-
waveforms[i] = tempCuts[0]
342+
waveforms[i, 0, :] = tempCuts[0]
340343
# append time stamp to list
341344
times.append(timeStamp)
342345

neo/io/neurosharectypesio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def __init__(self, filename='', dllname=''):
119119
"""
120120
BaseIO.__init__(self)
121121
self.dllname = dllname
122-
self.filename = filename
122+
self.filename = str(filename)
123123

124124
def read_segment(self, import_neuroshare_segment=True,
125125
lazy=False):

0 commit comments

Comments
 (0)