
Commit b58994d

RecordingExtractors (#171)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 3b45d09 commit b58994d

14 files changed, +1721 -1055 lines

src/guppy/extractors/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .base_recording_extractor import BaseRecordingExtractor, read_and_save_event, read_and_save_all_events
from .tdt_recording_extractor import TdtRecordingExtractor
from .csv_recording_extractor import CsvRecordingExtractor
from .doric_recording_extractor import DoricRecordingExtractor
from .npm_recording_extractor import NpmRecordingExtractor
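These re-exports define the public surface of the guppy.extractors package, so downstream code can pull the extractor classes and helpers from one place, e.g.:

from guppy.extractors import CsvRecordingExtractor, read_and_save_all_events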
src/guppy/extractors/base_recording_extractor.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
"""Base class for recording extractors."""

import logging
import multiprocessing as mp
import os
import time
from abc import ABC, abstractmethod
from itertools import repeat
from typing import Any

import h5py
import numpy as np

logger = logging.getLogger(__name__)

class BaseRecordingExtractor(ABC):
    """
    Abstract base class for recording extractors.

    Defines the interface contract for reading and saving fiber photometry
    data from various acquisition formats (TDT, Doric, CSV, NPM, etc.).
    """

    @classmethod
    @abstractmethod
    def discover_events_and_flags(cls) -> tuple[list[str], list[str]]:
        """
        Discover available events and format flags from data files.

        Returns
        -------
        events : list of str
            Names of all events/stores available in the dataset.
        flags : list of str
            Format indicators or file type flags.
        """
        # NOTE: This method signature is intentionally minimal and flexible.
        # Different formats have different discovery requirements:
        #   - TDT/CSV/Doric: need only the folder_path parameter
        #   - NPM: needs folder_path, num_ch, and an optional inputParameters
        #     for interleaved channels
        # Each child class defines its own signature with the parameters it needs.
        pass

    @abstractmethod
    def read(self, *, events: list[str], outputPath: str) -> list[dict[str, Any]]:
        """
        Read data from source files for specified events.

        Parameters
        ----------
        events : list of str
            List of event/store names to extract from the data.
        outputPath : str
            Path to the output directory.

        Returns
        -------
        list of dict
            List of dictionaries containing extracted data. Each dictionary
            represents one event/store and contains keys such as 'storename',
            'timestamps', 'data', 'sampling_rate', etc.
        """
        pass

    @abstractmethod
    def save(self, *, output_dicts: list[dict[str, Any]], outputPath: str) -> None:
        """
        Save extracted data dictionaries to HDF5 format.

        Parameters
        ----------
        output_dicts : list of dict
            List of data dictionaries from read().
        outputPath : str
            Path to the output directory.
        """
        pass

    @staticmethod
    def _write_hdf5(data: Any, storename: str, output_path: str, key: str) -> None:
        """
        Write data to HDF5 file.

        Parameters
        ----------
        data : array-like
            Data to write to the HDF5 file.
        storename : str
            Name of the store/event.
        output_path : str
            Directory path where HDF5 file will be written.
        key : str
            Key name for this data field in the HDF5 file.
        """
        # Replace invalid characters in storename to avoid filesystem errors
        storename = storename.replace("\\", "_")
        storename = storename.replace("/", "_")

        filepath = os.path.join(output_path, storename + ".hdf5")

        # Create a new file if it doesn't exist
        if not os.path.exists(filepath):
            with h5py.File(filepath, "w") as f:
                if isinstance(data, np.ndarray):
                    f.create_dataset(key, data=data, maxshape=(None,), chunks=True)
                else:
                    f.create_dataset(key, data=data)
        # Otherwise update the existing file in place
        else:
            with h5py.File(filepath, "r+") as f:
                if key in f:
                    if isinstance(data, np.ndarray):
                        # Resizable dataset: grow or shrink to fit the new data
                        f[key].resize(data.shape)
                        f[key][:] = data
                    else:
                        # Scalar dataset: overwrite the stored value
                        f[key][()] = data
                else:
                    if isinstance(data, np.ndarray):
                        f.create_dataset(key, data=data, maxshape=(None,), chunks=True)
                    else:
                        f.create_dataset(key, data=data)

def read_and_save_event(extractor, event, outputPath):
    """Read a single event with the given extractor and save it to disk."""
    output_dicts = extractor.read(events=[event], outputPath=outputPath)
    extractor.save(output_dicts=output_dicts, outputPath=outputPath)
    logger.info("Data for event {} fetched and stored.".format(event))


def read_and_save_all_events(extractor, events, outputPath, numProcesses=mp.cpu_count()):
    """Read and save all events in parallel, one worker per event."""
    logger.info("Reading data for events {} ...".format(events))

    start = time.time()
    with mp.Pool(numProcesses) as p:
        p.starmap(read_and_save_event, zip(repeat(extractor), events, repeat(outputPath)))
    logger.info("Time taken = {0:.5f}".format(time.time() - start))
src/guppy/extractors/csv_recording_extractor.py

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
import glob
import logging
import os
from typing import Any

import numpy as np
import pandas as pd

from guppy.extractors import BaseRecordingExtractor

logger = logging.getLogger(__name__)


class CsvRecordingExtractor(BaseRecordingExtractor):

    @classmethod
    def discover_events_and_flags(cls, folder_path) -> tuple[list[str], list[str]]:
        """
        Discover available events and format flags from CSV files.

        Parameters
        ----------
        folder_path : str
            Path to the folder containing CSV files.

        Returns
        -------
        events : list of str
            Names of all events/stores available in the dataset.
        flags : list of str
            Format indicators or file type flags.
        """
        logger.debug("Determining whether each file is an NPM, Doric, or standard csv file based on its structure")
        path = sorted(set(glob.glob(os.path.join(folder_path, "*.csv"))))

        flag = "None"
        event_from_filename = []
        flag_arr = []
        for i in range(len(path)):
            ext = os.path.basename(path[i]).split(".")[-1]
            assert ext == "csv", "Only .csv files are supported by CsvRecordingExtractor."
            # Peek at the first two rows: if every cell is non-numeric, the file
            # is most likely a Doric csv, which this extractor does not handle.
            df = pd.read_csv(path[i], header=None, nrows=2, index_col=False, dtype=str)
            df = df.dropna(axis=1, how="all")
            df_arr = np.array(df).flatten()
            check_all_str = []
            for element in df_arr:
                try:
                    float(element)
                except ValueError:
                    check_all_str.append(element)
            assert len(check_all_str) != len(df_arr), (
                "This file appears to be a Doric .csv. "
                "CsvRecordingExtractor only supports standard .csv files."
            )
            df = pd.read_csv(path[i], index_col=False)

            _, value = cls._check_header(df)

            # check dataframe structure and read data accordingly
            if len(value) > 0:
                # numeric column labels mean the file has no header row
                columns_isstr = False
                df = pd.read_csv(path[i], header=None)
            else:
                columns_isstr = True
            cols = np.array(list(df.columns), dtype=str)

            # check the structure of the dataframe and assign a flag for the file type
            if len(cols) == 1:
                if cols[0].lower() != "timestamps":
                    logger.error("\033[1m" + "Column name should be timestamps (all lower-case)" + "\033[0m")
                    raise Exception("\033[1m" + "Column name should be timestamps (all lower-case)" + "\033[0m")
                else:
                    flag = "event_csv"
            elif len(cols) == 3:
                arr1 = np.array(["timestamps", "data", "sampling_rate"])
                arr2 = np.char.lower(np.array(cols))
                if not (np.sort(arr1) == np.sort(arr2)).all():
                    logger.error(
                        "\033[1m"
                        + "Column names should be timestamps, data and sampling_rate (all lower-case)"
                        + "\033[0m"
                    )
                    raise Exception(
                        "\033[1m"
                        + "Column names should be timestamps, data and sampling_rate (all lower-case)"
                        + "\033[0m"
                    )
                else:
                    flag = "data_csv"
            elif len(cols) == 2 or len(cols) > 3:
                raise ValueError(
                    "Data appears to be a Neurophotometrics csv. Please use NpmRecordingExtractor to import the data."
                )
            else:
                logger.error("Number of columns in csv file does not make sense.")
                raise Exception("Number of columns in csv file does not make sense.")

            if columns_isstr and (
                "flags" in np.char.lower(np.array(cols)) or "ledstate" in np.char.lower(np.array(cols))
            ):
                flag = flag + "_v2"

            flag_arr.append(flag)
            logger.info(flag)
            assert flag == "event_csv" or flag == "data_csv", (
                "This extractor only supports standard event_csv and data_csv files."
            )
            name = os.path.basename(path[i]).split(".")[0]
            event_from_filename.append(name)

        logger.info("Importing of csv files is done.")
        return event_from_filename, flag_arr

    def __init__(self, folder_path):
        self.folder_path = folder_path

    @staticmethod
    def _check_header(df):
        # Collect column labels that parse as floats; a non-empty result
        # means the csv has no header row.
        arr = list(df.columns)
        check_float = []
        for i in arr:
            try:
                check_float.append(float(i))
            except ValueError:
                pass

        return arr, check_float

    def _read_csv(self, event):
        logger.debug("\033[1m" + "Trying to read data for {} from csv file.".format(event) + "\033[0m")
        filepath = os.path.join(self.folder_path, event + ".csv")
        if not os.path.exists(filepath):
            logger.error("\033[1m" + "No csv file found for event {}".format(event) + "\033[0m")
            raise Exception("\033[1m" + "No csv file found for event {}".format(event) + "\033[0m")

        df = pd.read_csv(filepath, index_col=False)
        return df

    def _save_to_hdf5(self, df, event, outputPath):
        key = list(df.columns)

        # TODO: clean up these if branches
        if len(key) == 3:
            arr1 = np.array(["timestamps", "data", "sampling_rate"])
            arr2 = np.char.lower(np.array(key))
            if not (np.sort(arr1) == np.sort(arr2)).all():
                logger.error("\033[1m" + "Column names should be timestamps, data and sampling_rate" + "\033[0m")
                raise Exception("\033[1m" + "Column names should be timestamps, data and sampling_rate" + "\033[0m")

        if len(key) == 1:
            if key[0].lower() != "timestamps":
                logger.error("\033[1m" + "Column name should be timestamps" + "\033[0m")
                raise Exception("\033[1m" + "Column name should be timestamps" + "\033[0m")

        if len(key) != 3 and len(key) != 1:
            msg = (
                "Number of columns in csv file should be either three or one: "
                "three columns if the file holds control or signal data, "
                "one column if the file holds event TTLs."
            )
            logger.error("\033[1m" + msg + "\033[0m")
            raise Exception("\033[1m" + msg + "\033[0m")

        for i in range(len(key)):
            self._write_hdf5(df[key[i]].dropna(), event, outputPath, key[i].lower())

        logger.info("\033[1m" + "Reading data for {} from csv file is completed.".format(event) + "\033[0m")

    def read(self, *, events: list[str], outputPath: str) -> list[dict[str, Any]]:
        output_dicts = []
        for event in events:
            df = self._read_csv(event=event)
            S = df.to_dict()
            S["storename"] = event
            output_dicts.append(S)
        return output_dicts

    def save(self, *, output_dicts: list[dict[str, Any]], outputPath: str) -> None:
        for S in output_dicts:
            event = S.pop("storename")
            df = pd.DataFrame.from_dict(S)
            self._save_to_hdf5(df=df, event=event, outputPath=outputPath)
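
Putting the pieces together, a minimal usage sketch (the folder and output paths are placeholders, not part of this commit): a session folder holding signal.csv with columns timestamps, data, sampling_rate and lick.csv with a single timestamps column could be processed as follows.

import os

from guppy.extractors import CsvRecordingExtractor

folder = "/data/session1"              # placeholder session folder
out = os.path.join(folder, "output")   # placeholder output directory
os.makedirs(out, exist_ok=True)        # _write_hdf5 expects the directory to exist

# Discovery returns one (event, flag) pair per .csv file,
# e.g. (["lick", "signal"], ["event_csv", "data_csv"]).
events, flags = CsvRecordingExtractor.discover_events_and_flags(folder)

extractor = CsvRecordingExtractor(folder)
output_dicts = extractor.read(events=events, outputPath=out)
extractor.save(output_dicts=output_dicts, outputPath=out)  # writes one <event>.hdf5 per event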
