Skip to content

Commit 4710543

Browse files
committed
refactor klustakwik to use pathlib instead of os.path
1 parent 60a0d9b commit 4710543

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

neo/io/klustakwikio.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515

1616
import glob
1717
import logging
18-
import os.path
18+
from pathlib import Path
1919
import shutil
2020

21-
# note neo.core need only numpy and quantitie
21+
# note neo.core need only numpy and quantities
2222
import numpy as np
2323

2424

@@ -87,27 +87,31 @@ class KlustaKwikIO(BaseIO):
8787
extensions = ['fet', 'clu', 'res', 'spk']
8888

8989
# Operates on directories
90-
mode = 'file'
90+
mode = 'dir'
9191

92-
def __init__(self, filename, sampling_rate=30000.):
92+
def __init__(self, dirname, sampling_rate=30000.):
9393
"""Create a new IO to operate on a directory
9494
95-
filename : the directory to contain the files
96-
basename : string, basename of KlustaKwik format, or None
95+
dirname : the directory to contain the files
9796
sampling_rate : in Hz, necessary because the KlustaKwik files
9897
stores data in samples.
9998
"""
10099
BaseIO.__init__(self)
101-
# self.filename = os.path.normpath(filename)
102-
self.filename, self.basename = os.path.split(os.path.abspath(filename))
100+
self.dirname = Path(dirname)
101+
# in case no basename is provided
102+
if self.dirname.is_dir():
103+
self.session_dir = self.dirname
104+
else:
105+
self.session_dir = self.dirname.parent
106+
self.basename = self.dirname.name
103107
self.sampling_rate = float(sampling_rate)
104108

105109
# error check
106-
if not os.path.isdir(self.filename):
107-
raise ValueError("filename must be a directory")
110+
if not self.session_dir.is_dir():
111+
raise ValueError("dirname must be in an existing directory")
108112

109113
# initialize a helper object to parse filenames
110-
self._fp = FilenameParser(dirname=self.filename, basename=self.basename)
114+
self._fp = FilenameParser(dirname=self.session_dir, basename=self.basename)
111115

112116
def read_block(self, lazy=False):
113117
"""Returns a Block containing spike information.
@@ -130,7 +134,7 @@ def read_block(self, lazy=False):
130134
return block
131135

132136
# Create a single segment to hold all of the data
133-
seg = Segment(name='seg0', index=0, file_origin=self.filename)
137+
seg = Segment(name='seg0', index=0, file_origin=str(self.session_dir / self.basename))
134138
block.segments.append(seg)
135139

136140
# Load spike times from each group and store in a dict, keyed
@@ -367,15 +371,13 @@ def _make_all_file_handles(self, block):
367371

368372
def _new_group(self, id_group, nbClusters):
369373
# generate filenames
370-
fetfilename = os.path.join(self.filename,
371-
self.basename + ('.fet.%d' % id_group))
372-
clufilename = os.path.join(self.filename,
373-
self.basename + ('.clu.%d' % id_group))
374+
fetfilename = self.session_dir / (self.basename + ('.fet.%d' % id_group))
375+
clufilename = self.session_dir / (self.basename + ('.clu.%d' % id_group))
374376

375377
# back up before overwriting
376-
if os.path.exists(fetfilename):
378+
if fetfilename.exists():
377379
shutil.copyfile(fetfilename, fetfilename + '~')
378-
if os.path.exists(clufilename):
380+
if clufilename.exists():
379381
shutil.copyfile(clufilename, clufilename + '~')
380382

381383
# create file handles
@@ -406,12 +408,12 @@ def __init__(self, dirname, basename=None):
406408
will be used. An error is raised if files with multiple basenames
407409
exist in the directory.
408410
"""
409-
self.dirname = os.path.normpath(dirname)
411+
self.dirname = Path(dirname).absolute()
410412
self.basename = basename
411413

412414
# error check
413-
if not os.path.isdir(self.dirname):
414-
raise ValueError("filename must be a directory")
415+
if not self.dirname.is_dir():
416+
raise ValueError("dirname must be a directory")
415417

416418
def read_filenames(self, typestring='fet'):
417419
"""Returns filenames in the data directory matching the type.
@@ -430,13 +432,13 @@ def read_filenames(self, typestring='fet'):
430432
a sequence of digits are valid. The digits are converted to an integer
431433
and used as the group number.
432434
"""
433-
all_filenames = glob.glob(os.path.join(self.dirname, '*'))
435+
all_filenames = self.dirname.glob('*')
434436

435437
# Fill the dict with valid filenames
436438
d = {}
437439
for v in all_filenames:
438440
# Test whether matches format, ie ends with digits
439-
split_fn = os.path.split(v)[1]
441+
split_fn = v.name
440442
m = glob.re.search((r'^(\w+)\.%s\.(\d+)$' % typestring), split_fn)
441443
if m is not None:
442444
# get basename from first hit if not specified

0 commit comments

Comments
 (0)