import glob
import logging
import os
from typing import Any

import numpy as np
import pandas as pd

from guppy.extractors import BaseRecordingExtractor

logger = logging.getLogger(__name__)


class CsvRecordingExtractor(BaseRecordingExtractor):

    @classmethod
    def discover_events_and_flags(cls, folder_path: str) -> tuple[list[str], list[str]]:
        """
        Discover available events and format flags from CSV files.

        Parameters
        ----------
        folder_path : str
            Path to the folder containing CSV files.

        Returns
        -------
        events : list of str
            Names of all events/stores available in the dataset.
        flags : list of str
            Format indicators or file type flags.
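
        Examples
        --------
        Illustrative sketch; the folder path below is hypothetical::

            events, flags = CsvRecordingExtractor.discover_events_and_flags(
                "/data/session1"
            )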
        """
        logger.debug("Discovering events and flags from csv files based on their structure")
        path = sorted(set(glob.glob(os.path.join(folder_path, "*.csv"))))

        flag = "None"
        event_from_filename = []
        flag_arr = []
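        # Two layouts are supported: an event csv with a single "timestamps"
        # column, and a data csv with "timestamps", "data" and "sampling_rate"
        # columns; anything else is rejected inside the loop below.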
        for file_path in path:
            ext = os.path.splitext(file_path)[1]
            assert ext == ".csv", "Only .csv files are supported by import_csv function."
            # Doric exports start with a header block in which every cell of the
            # first two rows is non-numeric, so reject those files here.
            df = pd.read_csv(file_path, header=None, nrows=2, index_col=False, dtype=str)
            df = df.dropna(axis=1, how="all")
            df_arr = np.array(df).flatten()
            non_numeric = []
            for element in df_arr:
                try:
                    float(element)
                except (TypeError, ValueError):
                    non_numeric.append(element)
            assert len(non_numeric) != len(
                df_arr
            ), "This file appears to be doric .csv. This function only supports standard .csv files."
            df = pd.read_csv(file_path, index_col=False)

            _, value = cls._check_header(df)

            # check dataframe structure and read data accordingly
            if len(value) > 0:
                # numeric column labels mean the file has no header row
                columns_isstr = False
                df = pd.read_csv(file_path, header=None)
            else:
                columns_isstr = True
            cols = np.array(list(df.columns), dtype=str)
            # check the structure of the dataframe and assign a flag for the file type
            if len(cols) == 1:
                if cols[0].lower() != "timestamps":
                    logger.error("\033[1m" + "Column name should be timestamps (all lower-case)" + "\033[0m")
                    raise Exception("\033[1m" + "Column name should be timestamps (all lower-case)" + "\033[0m")
                flag = "event_csv"
            elif len(cols) == 3:
                expected = np.array(["timestamps", "data", "sampling_rate"])
                actual = np.char.lower(cols)
                if not (np.sort(expected) == np.sort(actual)).all():
                    logger.error(
                        "\033[1m"
                        + "Column names should be timestamps, data and sampling_rate (all lower-case)"
                        + "\033[0m"
                    )
                    raise Exception(
                        "\033[1m"
                        + "Column names should be timestamps, data and sampling_rate (all lower-case)"
                        + "\033[0m"
                    )
                flag = "data_csv"
            elif len(cols) == 2 or len(cols) > 3:
                raise ValueError(
                    "Data appears to be Neurophotometrics csv. Please use import_npm_csv function to import the data."
                )
            else:
                logger.error("Number of columns in csv file does not make sense.")
                raise Exception("Number of columns in csv file does not make sense.")

            if columns_isstr and ("flags" in np.char.lower(cols) or "ledstate" in np.char.lower(cols)):
                flag = flag + "_v2"

            flag_arr.append(flag)
            logger.info(flag)
            assert flag in (
                "event_csv",
                "data_csv",
            ), "This function only supports standard event_csv and data_csv files."
            name = os.path.splitext(os.path.basename(file_path))[0]
            event_from_filename.append(name)

        logger.info("Discovery of events and flags from csv files is done.")
        return event_from_filename, flag_arr

    def __init__(self, folder_path: str):
        self.folder_path = folder_path

    @staticmethod
    def _check_header(df):
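        # Column labels that parse as floats mean the csv has no header row:
        # for a headerless file whose first line is "0.0,1.25,30.0" (illustrative),
        # pandas promotes that row to column names, check_float comes back
        # non-empty, and the caller re-reads the file with header=None.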
        arr = list(df.columns)
        check_float = []
        for i in arr:
            try:
                check_float.append(float(i))
            except (TypeError, ValueError):
                pass

        return arr, check_float

    def _read_csv(self, event):
        logger.debug("\033[1m" + f"Trying to read data for {event} from csv file." + "\033[0m")
        file_path = os.path.join(self.folder_path, event + ".csv")
        if not os.path.exists(file_path):
            logger.error("\033[1m" + f"No csv file found for event {event}" + "\033[0m")
            raise Exception("\033[1m" + f"No csv file found for event {event}" + "\033[0m")

        df = pd.read_csv(file_path, index_col=False)
        return df

    def _save_to_hdf5(self, df, event, outputPath):
        key = list(df.columns)

        if len(key) == 3:
            expected = np.array(["timestamps", "data", "sampling_rate"])
            actual = np.char.lower(np.array(key))
            if not (np.sort(expected) == np.sort(actual)).all():
                logger.error("\033[1m" + "Column names should be timestamps, data and sampling_rate" + "\033[0m")
                raise Exception("\033[1m" + "Column names should be timestamps, data and sampling_rate" + "\033[0m")
        elif len(key) == 1:
            if key[0].lower() != "timestamps":
                logger.error("\033[1m" + "Column name should be timestamps" + "\033[0m")
                raise Exception("\033[1m" + "Column name should be timestamps" + "\033[0m")
        else:
            message = (
                "Number of columns in csv file should be either three or one: "
                "three columns if the file holds control or signal data, "
                "one column if it holds event TTLs."
            )
            logger.error("\033[1m" + message + "\033[0m")
            raise Exception("\033[1m" + message + "\033[0m")

        for k in key:
            self._write_hdf5(df[k].dropna(), event, outputPath, k.lower())

        logger.info("\033[1m" + f"Saving data for {event} to hdf5 is completed." + "\033[0m")

    def read(self, *, events: list[str], outputPath: str) -> list[dict[str, Any]]:
        # outputPath is unused when reading csv; presumably it is kept to match
        # the shared extractor interface.
        output_dicts = []
        for event in events:
            df = self._read_csv(event=event)
            S = df.to_dict()
            S["storename"] = event
            output_dicts.append(S)
        return output_dicts

    def save(self, *, output_dicts: list[dict[str, Any]], outputPath: str) -> None:
        for S in output_dicts:
            # pop storename so the remaining keys become DataFrame columns
            event = S.pop("storename")
            df = pd.DataFrame.from_dict(S)
            self._save_to_hdf5(df=df, event=event, outputPath=outputPath)
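

# Usage sketch (hypothetical paths; assumes one <event>.csv per store in the
# recording folder, which is the layout the methods above expect):
#
#     folder = "/data/session1"
#     events, flags = CsvRecordingExtractor.discover_events_and_flags(folder)
#     extractor = CsvRecordingExtractor(folder)
#     out = os.path.join(folder, "output")
#     dicts = extractor.read(events=events, outputPath=out)
#     extractor.save(output_dicts=dicts, outputPath=out)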