Skip to content

Commit 2941231

Browse files
committed
feat: upstreaming refactored device class from erb-thesis (WIP)
1 parent ff5f018 commit 2941231

File tree

5 files changed

+452
-297
lines changed

5 files changed

+452
-297
lines changed

eegnb/devices/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .base import EEGDevice
2+
from .muse import MuseDevice
3+
from ._brainflow import BrainflowDevice
4+
5+
all_devices = MuseDevice.devices + BrainflowDevice.devices

eegnb/devices/_brainflow.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import logging
2+
from time import sleep
3+
from multiprocessing import Process
4+
from typing import List, Tuple
5+
6+
import numpy as np
7+
import pandas as pd
8+
9+
from brainflow import BoardShim, BoardIds, BrainFlowInputParams
10+
from .base import EEGDevice, _check_samples
11+
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class BrainflowDevice(EEGDevice):
17+
# list of brainflow devices
18+
devices: List[str] = [
19+
"ganglion",
20+
"ganglion_wifi",
21+
"cyton",
22+
"cyton_wifi",
23+
"cyton_daisy",
24+
"cyton_daisy_wifi",
25+
"brainbit",
26+
"unicorn",
27+
"synthetic",
28+
"brainbit",
29+
"notion1",
30+
"notion2",
31+
]
32+
33+
def __init__(
34+
self,
35+
device_name: str,
36+
serial_num=None,
37+
serial_port=None,
38+
mac_addr=None,
39+
other=None,
40+
ip_addr=None,
41+
):
42+
EEGDevice.__init__(self, device_name)
43+
self.serial_num = serial_num
44+
self.serial_port = serial_port
45+
self.mac_address = mac_addr
46+
self.other = other
47+
self.ip_addr = ip_addr
48+
self.markers: List[Tuple[List[int], float]] = []
49+
self._init_brainflow()
50+
51+
def start(self, filename: str = None, duration=None) -> None:
52+
self.save_fn = filename
53+
54+
def record():
55+
sleep(duration)
56+
self._stop_brainflow()
57+
58+
self.board.start_stream()
59+
if duration:
60+
logger.info(
61+
"Starting background recording process, will save to file: %s"
62+
% self.save_fn
63+
)
64+
self.recording = Process(target=lambda: record())
65+
self.recording.start()
66+
67+
def stop(self) -> None:
68+
self._stop_brainflow()
69+
70+
def push_sample(self, marker: List[int], timestamp: float):
71+
last_timestamp = self.board.get_current_board_data(1)[-1][0]
72+
self.markers.append((marker, last_timestamp))
73+
74+
def check(self, max_uv_abs=200) -> List[str]:
75+
data = self.board.get_board_data() # will clear board buffer
76+
# print(data)
77+
channel_names = BoardShim.get_eeg_names(self.brainflow_id)
78+
# FIXME: _check_samples expects different (Muse) inputs
79+
checked = _check_samples(data.T, channel_names, max_uv_abs=max_uv_abs) # type: ignore
80+
bads = [ch for ch, ok in checked.items() if not ok]
81+
return bads
82+
83+
def _init_brainflow(self) -> None:
84+
"""
85+
This function initializes the brainflow backend based on the input device name. It calls
86+
a utility function to determine the appropriate USB port to use based on the current operating system.
87+
Additionally, the system allows for passing a serial number in the case that they want to use either
88+
the BrainBit or the Unicorn EEG devices from the brainflow family.
89+
90+
Parameters:
91+
serial_num (str or int): serial number for either the BrainBit or Unicorn devices.
92+
"""
93+
from eegnb.devices.utils import get_openbci_usb
94+
95+
# Initialize brainflow parameters
96+
self.brainflow_params = BrainFlowInputParams()
97+
98+
device_name_to_id = {
99+
"ganglion": BoardIds.GANGLION_BOARD.value,
100+
"ganglion_wifi": BoardIds.GANGLION_WIFI_BOARD.value,
101+
"cyton": BoardIds.CYTON_BOARD.value,
102+
"cyton_wifi": BoardIds.CYTON_WIFI_BOARD.value,
103+
"cyton_daisy": BoardIds.CYTON_DAISY_BOARD.value,
104+
"cyton_daisy_wifi": BoardIds.CYTON_DAISY_WIFI_BOARD.value,
105+
"brainbit": BoardIds.BRAINBIT_BOARD.value,
106+
"unicorn": BoardIds.UNICORN_BOARD.value,
107+
"callibri_eeg": BoardIds.CALLIBRI_EEG_BOARD.value,
108+
"notion1": BoardIds.NOTION_1_BOARD.value,
109+
"notion2": BoardIds.NOTION_2_BOARD.value,
110+
"synthetic": BoardIds.SYNTHETIC_BOARD.value,
111+
}
112+
113+
# validate mapping
114+
assert all(name in device_name_to_id for name in self.devices)
115+
116+
self.brainflow_id = device_name_to_id[self.device_name]
117+
118+
if self.device_name == "ganglion":
119+
if self.serial_port is None:
120+
self.brainflow_params.serial_port = get_openbci_usb()
121+
# set mac address parameter in case
122+
if self.mac_address is None:
123+
logger.info(
124+
"No MAC address provided, attempting to connect without one"
125+
)
126+
else:
127+
self.brainflow_params.mac_address = self.mac_address
128+
129+
elif self.device_name in ["ganglion_wifi", "cyton_wifi", "cyton_daisy_wifi"]:
130+
if self.ip_addr is not None:
131+
self.brainflow_params.ip_address = self.ip_addr
132+
133+
elif self.device_name in ["cyton", "cyton_daisy"]:
134+
if self.serial_port is None:
135+
self.brainflow_params.serial_port = get_openbci_usb()
136+
137+
elif self.device_name == "callibri_eeg":
138+
if self.other:
139+
self.brainflow_params.other_info = str(self.other)
140+
141+
# some devices allow for an optional serial number parameter for better connection
142+
if self.serial_num:
143+
self.brainflow_params.serial_number = str(self.serial_num)
144+
145+
if self.serial_port:
146+
self.brainflow_params.serial_port = str(self.serial_port)
147+
148+
# Initialize board_shim
149+
self.sfreq = BoardShim.get_sampling_rate(self.brainflow_id)
150+
self.board = BoardShim(self.brainflow_id, self.brainflow_params)
151+
self.board.prepare_session()
152+
153+
def get_data(self) -> pd.DataFrame:
154+
from eegnb.devices.utils import create_stim_array
155+
156+
data = self.board.get_board_data() # will clear board buffer
157+
158+
# transform data for saving
159+
data = data.T # transpose data
160+
print(data)
161+
162+
# get the channel names for EEG data
163+
if self.brainflow_id == BoardIds.GANGLION_BOARD.value:
164+
# if a ganglion is used, use recommended default EEG channel names
165+
ch_names = ["fp1", "fp2", "tp7", "tp8"]
166+
else:
167+
# otherwise select eeg channel names via brainflow API
168+
ch_names = BoardShim.get_eeg_names(self.brainflow_id)
169+
170+
# pull EEG channel data via brainflow API
171+
eeg_data = data[:, BoardShim.get_eeg_channels(self.brainflow_id)]
172+
timestamps = data[:, BoardShim.get_timestamp_channel(self.brainflow_id)]
173+
174+
# Create a column for the stimuli to append to the EEG data
175+
stim_array = create_stim_array(timestamps, self.markers)
176+
timestamps = timestamps[
177+
..., None
178+
] # Add an additional dimension so that shapes match
179+
total_data = np.append(timestamps, eeg_data, 1)
180+
total_data = np.append(
181+
total_data, stim_array, 1
182+
) # Append the stim array to data.
183+
184+
# Subtract five seconds of settling time from beginning
185+
# total_data = total_data[5 * self.sfreq :]
186+
df = pd.DataFrame(total_data, columns=["timestamps"] + ch_names + ["stim"])
187+
return df
188+
189+
def _save(self) -> None:
190+
"""Saves the data to a CSV file."""
191+
assert self.save_fn
192+
df = self.get_data()
193+
df.to_csv(self.save_fn, index=False)
194+
195+
def _stop_brainflow(self) -> None:
196+
"""This functions kills the brainflow backend and saves the data to a CSV file."""
197+
# Collect session data and kill session
198+
if self.save_fn:
199+
self._save()
200+
self.board.stop_stream()
201+
self.board.release_session()
202+
203+
204+
def test_check():
205+
device = BrainflowDevice(device_name="synthetic")
206+
with device:
207+
sleep(2) # is 2s really needed?
208+
bads = device.check(max_uv_abs=300)
209+
# Seems to blink between the two...
210+
assert bads == ["F6", "F8"] or bads == ["F4", "F6", "F8"]
211+
# print(bads)
212+
# assert not bads
213+
214+
215+
def test_get_data():
216+
device = BrainflowDevice(device_name="synthetic")
217+
with device:
218+
sleep(2)
219+
df = device.get_data()
220+
print(df)
221+
assert not df.empty

eegnb/devices/base.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
Abstraction for the various supported EEG devices.
3+
"""
4+
5+
import logging
6+
from typing import List, Dict
7+
from abc import ABCMeta, abstractmethod
8+
9+
import numpy as np
10+
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
def _check_samples(
16+
buffer: np.ndarray, channels: List[str], max_uv_abs=200
17+
) -> Dict[str, bool]:
18+
# TODO: Better signal quality check
19+
chmax = dict(zip(channels, np.max(np.abs(buffer), axis=0)))
20+
return {ch: maxval < max_uv_abs for ch, maxval in chmax.items()}
21+
22+
23+
def test_check_samples():
24+
buffer = np.array([[9.0, 11.0, -5, -13]])
25+
assert {"TP9": True, "AF7": False, "AF8": True, "TP10": False} == _check_samples(
26+
buffer, channels=["TP9", "AF7", "AF8", "TP10"], max_uv_abs=10
27+
)
28+
29+
30+
class EEGDevice(metaclass=ABCMeta):
31+
def __init__(self, device: str) -> None:
32+
"""
33+
The initialization function takes the name of the EEG device and initializes the appropriate backend.
34+
35+
Parameters:
36+
device (str): name of eeg device used for reading data.
37+
"""
38+
self.device_name = device
39+
40+
@classmethod
41+
def create(cls, device_name: str, *args, **kwargs) -> "EEGDevice":
42+
from .muse import MuseDevice
43+
from ._brainflow import BrainflowDevice
44+
45+
if device_name in BrainflowDevice.devices:
46+
return BrainflowDevice(device_name)
47+
elif device_name in MuseDevice.devices:
48+
return MuseDevice(device_name)
49+
else:
50+
raise ValueError(f"Invalid device name: {device_name}")
51+
52+
def __enter__(self):
53+
self.start()
54+
55+
def __exit__(self, *args):
56+
self.stop()
57+
58+
@abstractmethod
59+
def start(self, filename: str = None, duration=None):
60+
"""
61+
Starts the EEG device based on the defined backend.
62+
63+
Parameters:
64+
filename (str): name of the file to save the sessions data to.
65+
"""
66+
raise NotImplementedError
67+
68+
@abstractmethod
69+
def stop(self):
70+
raise NotImplementedError
71+
72+
@abstractmethod
73+
def push_sample(self, marker: List[int], timestamp: float):
74+
"""
75+
Push a marker and its timestamp to store alongside the EEG data.
76+
77+
Parameters:
78+
marker (int): marker number for the stimuli being presented.
79+
timestamp (float): timestamp of stimulus onset from time.time() function.
80+
"""
81+
raise NotImplementedError
82+
83+
def get_samples(self):
84+
raise NotImplementedError
85+
86+
@abstractmethod
87+
def check(self):
88+
raise NotImplementedError
89+
90+
91+
def test_create():
92+
device = EEGDevice.create("synthetic")
93+
assert device

0 commit comments

Comments
 (0)