|
| 1 | +import time |
| 2 | +import typing |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import numpy.typing as npt |
| 6 | +import xipppy as xp |
| 7 | + |
| 8 | + |
| 9 | +CLOCK_RATE = 30_000 |
| 10 | +STREAM_RATES: dict = { |
| 11 | + "raw": CLOCK_RATE, |
| 12 | + "hi-res": 2_000, |
| 13 | + "lfp": 1_000 |
| 14 | +} |
| 15 | + |
| 16 | + |
| 17 | +class RippleDevice: |
| 18 | + def __init__( |
| 19 | + self, |
| 20 | + targ_stream_type: str = "hi-res", # for now only supporting raw, hi-res, lfp |
| 21 | + fetch_delay: float = 0.004, |
| 22 | + ): |
| 23 | + if targ_stream_type not in STREAM_RATES: |
| 24 | + raise Exception(f"Stream types other than {STREAM_RATES.keys()} are not supported") |
| 25 | + |
| 26 | + self._targ_st = targ_stream_type |
| 27 | + self._fetch_delay = fetch_delay |
| 28 | + try: |
| 29 | + udp_mode = xp._open() |
| 30 | + print("Connected through UDP") |
| 31 | + except: |
| 32 | + print("Failed to connect through UDP \nAttempting TCP") |
| 33 | + try: |
| 34 | + tcp_mode = xp._open(use_tcp=True) |
| 35 | + print("Connected through TCP") |
| 36 | + except: |
| 37 | + raise Exception("Could not connect to processor. \nMake sure power LED is connected and try again") |
| 38 | + |
| 39 | + # For each electrode, enable only the target stream, disable all others |
| 40 | + electrode_ids = xp.list_elec("all") |
| 41 | + self._elec_ids: list[int] = [] |
| 42 | + for el_id in electrode_ids: |
| 43 | + for st in xp.get_fe_streams(int(el_id)): |
| 44 | + b_targ_st = st == self._targ_st |
| 45 | + xp.signal_set(el_id, st, b_targ_st) |
| 46 | + if b_targ_st: |
| 47 | + self._elec_ids.append(int(el_id)) |
| 48 | + |
| 49 | + # TODO: Filtering? |
| 50 | + |
| 51 | + self._t0 = xp.time() |
| 52 | + |
| 53 | + @property |
| 54 | + def stream_type(self): |
| 55 | + return self._targ_st |
| 56 | + |
| 57 | + @property |
| 58 | + def elec_ids(self) -> list[int]: |
| 59 | + return self._elec_ids |
| 60 | + |
| 61 | + @property |
| 62 | + def srate(self) -> float: |
| 63 | + return float(STREAM_RATES.get(self.stream_type, 1_000)) |
| 64 | + |
| 65 | + def __del__(self): |
| 66 | + for st in ["raw", "stim", "hi-res", "lfp", "spk"]: |
| 67 | + for el_id in self._elec_ids: |
| 68 | + xp.signal_set(el_id, st, False) |
| 69 | + time.sleep(1.0) |
| 70 | + xp._close() |
| 71 | + |
| 72 | + @staticmethod |
| 73 | + def time() -> int: |
| 74 | + return xp.time() |
| 75 | + |
| 76 | + def fetch(self) -> typing.Tuple[npt.NDArray[np.float32], int]: |
| 77 | + t_now = xp.time() |
| 78 | + t_elapsed = max(0, (t_now - self._t0) / CLOCK_RATE) |
| 79 | + # Only fetch up to t_now - _fetch_delay, never beyond! |
| 80 | + fetch_points = max(int((t_elapsed - self._fetch_delay) * self.srate), 0) |
| 81 | + |
| 82 | + # Fetch |
| 83 | + data, timestamp = None, 0 |
| 84 | + if fetch_points > 0: |
| 85 | + if self._targ_st == "raw": |
| 86 | + data, timestamp = xp.cont_raw(fetch_points, self._elec_ids, start_timestamp=self._t0) |
| 87 | + elif self._targ_st == "hi-res": |
| 88 | + data, timestamp = xp.cont_hires(fetch_points, self._elec_ids, start_timestamp=self._t0) |
| 89 | + elif self._targ_st == "lfp": |
| 90 | + data, timestamp = xp.cont_lfp(fetch_points, self._elec_ids, start_timestamp=self._t0) |
| 91 | + else: |
| 92 | + raise Exception(f"Unsupported ripple stream type: {self._targ_st} ") |
| 93 | + |
| 94 | + if data is not None: |
| 95 | + if not len(self._elec_ids): |
| 96 | + raise Exception("Data channel count iz zero. Are electrodes connected?") |
| 97 | + # Note atypical memory layout: channels x samples. |
| 98 | + data = np.array(data).reshape(len(self._elec_ids), -1) |
| 99 | + data = np.ascontiguousarray(data.T) |
| 100 | + if data.shape[0] > 0: |
| 101 | + if data.shape[0] != fetch_points: |
| 102 | + raise Exception("API returned unexpected number of points. Data missing.") |
| 103 | + self._t0 += int(fetch_points * CLOCK_RATE / self.srate) |
| 104 | + timestamp = self._t0 |
| 105 | + |
| 106 | + return data, timestamp |
0 commit comments