Skip to content

Commit b2d179f

Browse files
Split preprocessing into modular scripts: loader, resampler, segmenter
1 parent 49a5c4a commit b2d179f

File tree

4 files changed

+389
-162
lines changed

4 files changed

+389
-162
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
MOSS - loader.py
3+
================
4+
Handles loading Muse 2 EEG data from CSV files.
5+
Supports all three Muse export formats and auto-detects sample rate.
6+
7+
Supported formats:
8+
- Mind Monitor app: RAW_TP9, RAW_AF7, RAW_AF8, RAW_TP10
9+
- MuseLSL: TP9, AF7, AF8, TP10
10+
- Direct export: channel1, channel2, channel3, channel4
11+
12+
Used by: preprocessing.py
13+
"""
14+
15+
import numpy as np
16+
import pandas as pd
17+
from typing import Optional
18+
19+
# ── Constants ──────────────────────────────────────────────────────────────────
20+
DEFAULT_FS = 256 # fallback sample rate if detection fails
21+
22+
# Column names per export format
23+
MIND_MONITOR_COLS = ['RAW_TP9', 'RAW_AF7', 'RAW_AF8', 'RAW_TP10']
24+
MUSELSL_COLS = ['TP9', 'AF7', 'AF8', 'TP10']
25+
EXPORT_COLS = ['channel1', 'channel2', 'channel3', 'channel4']
26+
27+
# Canonical channel order used throughout MOSS
28+
CHANNEL_NAMES = ['TP9', 'AF7', 'AF8', 'TP10']
29+
30+
31+
def load_csv(filepath: str) -> tuple[np.ndarray, int]:
32+
"""
33+
Load a Muse 2 CSV file and return raw EEG as a numpy array.
34+
35+
Args:
36+
filepath: path to Muse 2 CSV file
37+
38+
Returns:
39+
eeg: (n_samples, 4) float32 array, channels = [TP9, AF7, AF8, TP10]
40+
src_fs: detected sample rate in Hz
41+
42+
Raises:
43+
ValueError: if the CSV format is not recognized
44+
"""
45+
df = pd.read_csv(filepath)
46+
cols = df.columns.tolist()
47+
48+
# Detect which export format this is
49+
if 'RAW_TP9' in cols:
50+
eeg_cols = MIND_MONITOR_COLS
51+
ts_col = 'TimeStamp' if 'TimeStamp' in cols else None
52+
elif 'TP9' in cols:
53+
eeg_cols = MUSELSL_COLS
54+
ts_col = 'timestamps' if 'timestamps' in cols else None
55+
elif 'channel1' in cols:
56+
eeg_cols = EXPORT_COLS
57+
ts_col = 'time' if 'time' in cols else None
58+
else:
59+
raise ValueError(
60+
f"Unrecognized CSV format. Expected RAW_TP9/AF7/AF8/TP10, "
61+
f"TP9/AF7/AF8/TP10, or channel1-4. Got columns: {cols[:10]}"
62+
)
63+
64+
df = df[([ts_col] if ts_col else []) + eeg_cols].dropna(subset=eeg_cols)
65+
eeg = df[eeg_cols].values.astype(np.float32)
66+
67+
src_fs = detect_sample_rate(df, ts_col)
68+
69+
return eeg, src_fs
70+
71+
72+
def detect_sample_rate(df: pd.DataFrame,
73+
ts_col: Optional[str],
74+
default: int = DEFAULT_FS) -> int:
75+
"""
76+
Estimate sample rate from a timestamp column.
77+
Falls back to default if timestamps are unavailable or unparseable.
78+
79+
Args:
80+
df: DataFrame containing the timestamp column
81+
ts_col: name of the timestamp column (None if not present)
82+
default: fallback sample rate
83+
84+
Returns:
85+
sample rate in Hz (clamped to 100-512 range)
86+
"""
87+
if ts_col is None or ts_col not in df.columns:
88+
return default
89+
90+
# Try parsing as datetime strings (Mind Monitor format)
91+
try:
92+
ts = pd.to_datetime(df[ts_col])
93+
dt = (ts.iloc[-1] - ts.iloc[0]).total_seconds()
94+
if dt > 0:
95+
fs = int(round(len(df) / dt))
96+
return max(100, min(512, fs))
97+
except Exception:
98+
pass
99+
100+
# Try parsing as unix epoch floats (MuseLSL format)
101+
try:
102+
ts = df[ts_col].values.astype(float)
103+
diffs = np.diff(ts)
104+
diffs = diffs[diffs > 0]
105+
if len(diffs) > 10:
106+
fs = int(round(1.0 / np.median(diffs)))
107+
return max(100, min(512, fs))
108+
except Exception:
109+
pass
110+
111+
return default
112+
113+
114+
if __name__ == '__main__':
115+
import sys
116+
if len(sys.argv) < 2:
117+
print("Usage: python loader.py path/to/recording.csv")
118+
sys.exit(1)
119+
120+
eeg, fs = load_csv(sys.argv[1])
121+
print(f"Loaded: {eeg.shape} | Sample rate: {fs} Hz")
122+
print(f"Duration: {eeg.shape[0] / fs:.1f}s")
123+
print(f"Channels: {CHANNEL_NAMES}")
Lines changed: 45 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -1,145 +1,79 @@
11
"""
22
MOSS - preprocessing.py
33
=======================
4-
Handles all EEG signal preprocessing:
5-
- CSV loading (Mind Monitor + MuseLSL formats)
6-
- Sample rate detection
7-
- Resampling to 200Hz
8-
- Segmentation into 4-second windows
4+
Orchestrates the full EEG preprocessing pipeline:
5+
load -> resample -> segment
96
10-
Input: path to a Muse 2 CSV file
7+
This is the main entry point for preprocessing.
8+
Each step is implemented in its own module:
9+
loader.py — CSV loading + sample rate detection
10+
resampler.py — resampling to 200Hz
11+
segmenter.py — slicing into 4-second windows
12+
13+
Input: path to a Muse 2 CSV file (or raw numpy array for LSL streaming)
1114
Output: list of (4, 800) numpy arrays — one per segment
1215
1316
Used by: coordinator.py
1417
"""
1518

1619
import numpy as np
17-
import pandas as pd
18-
from scipy import signal as scipy_signal
1920
from typing import Optional
2021

21-
# ── Constants ──────────────────────────────────────────────────────────────────
22-
TGT_FS = 200 # NeuroLM expected sample rate (Hz)
23-
WIN_SEC = 4 # window length in seconds
24-
STEP_SEC = 2 # step size (50% overlap)
25-
WIN_SAMPLES = TGT_FS * WIN_SEC # 800 samples per window
26-
STEP_SAMPLES = TGT_FS * STEP_SEC # 400 samples per step
27-
DEFAULT_FS = 256 # fallback sample rate if detection fails
28-
29-
# Column names for different Muse export formats
30-
MIND_MONITOR_COLS = ['RAW_TP9', 'RAW_AF7', 'RAW_AF8', 'RAW_TP10']
31-
MUSELSL_COLS = ['TP9', 'AF7', 'AF8', 'TP10']
32-
EXPORT_COLS = ['channel1', 'channel2', 'channel3', 'channel4']
22+
from loader import load_csv, CHANNEL_NAMES
23+
from resampler import resample, TGT_FS
24+
from segmenter import segment, WIN_SEC, WIN_SAMPLES, STEP_SAMPLES
3325

34-
# Canonical channel order used throughout MOSS
35-
CHANNEL_NAMES = ['TP9', 'AF7', 'AF8', 'TP10']
3626

37-
38-
def load_csv(filepath: str) -> tuple[np.ndarray, int]:
27+
def preprocess(filepath: str) -> tuple[list[np.ndarray], int, float]:
3928
"""
40-
Load a Muse 2 CSV file and return raw EEG as a numpy array.
29+
Full preprocessing pipeline from CSV file: load -> resample -> segment.
4130
42-
Supports three export formats:
43-
- Mind Monitor: RAW_TP9, RAW_AF7, RAW_AF8, RAW_TP10
44-
- MuseLSL: TP9, AF7, AF8, TP10
45-
- Direct export: channel1, channel2, channel3, channel4
31+
Args:
32+
filepath: path to Muse 2 CSV file (any supported format)
4633
4734
Returns:
48-
eeg: np.ndarray of shape (n_samples, 4), dtype float32
49-
src_fs: detected sample rate in Hz
50-
"""
51-
df = pd.read_csv(filepath)
52-
cols = df.columns.tolist()
53-
54-
# Detect column format
55-
if 'RAW_TP9' in cols:
56-
eeg_cols = MIND_MONITOR_COLS
57-
ts_col = 'TimeStamp' if 'TimeStamp' in cols else None
58-
elif 'TP9' in cols:
59-
eeg_cols = MUSELSL_COLS
60-
ts_col = 'timestamps' if 'timestamps' in cols else None
61-
elif 'channel1' in cols:
62-
eeg_cols = EXPORT_COLS
63-
ts_col = 'time' if 'time' in cols else None
64-
else:
65-
raise ValueError(
66-
f"Unrecognized CSV format. Expected RAW_TP9/AF7/AF8/TP10, "
67-
f"TP9/AF7/AF8/TP10, or channel1-4. Got columns: {cols[:10]}"
68-
)
69-
70-
df = df[([ts_col] if ts_col else []) + eeg_cols].dropna(subset=eeg_cols)
71-
eeg = df[eeg_cols].values.astype(np.float32)
72-
73-
# Detect sample rate from timestamps
74-
src_fs = _detect_sample_rate(df, ts_col)
75-
76-
return eeg, src_fs
77-
78-
79-
def _detect_sample_rate(df: pd.DataFrame, ts_col: Optional[str]) -> int:
80-
"""Estimate sample rate from timestamp column, fallback to DEFAULT_FS."""
81-
if ts_col is None or ts_col not in df.columns:
82-
return DEFAULT_FS
83-
try:
84-
ts = pd.to_datetime(df[ts_col])
85-
dt = (ts.iloc[-1] - ts.iloc[0]).total_seconds()
86-
if dt > 0:
87-
fs = int(round(len(df) / dt))
88-
return max(100, min(512, fs)) # clamp to sane range
89-
except Exception:
90-
pass
91-
92-
try:
93-
# MuseLSL format: unix epoch floats
94-
ts = df[ts_col].values.astype(float)
95-
diffs = np.diff(ts)
96-
diffs = diffs[diffs > 0]
97-
if len(diffs) > 10:
98-
fs = int(round(1.0 / np.median(diffs)))
99-
return max(100, min(512, fs))
100-
except Exception:
101-
pass
35+
segments: list of (4, 800) numpy arrays, ready for encoder.py
36+
src_fs: detected source sample rate in Hz
37+
duration: recording duration in seconds
10238
103-
return DEFAULT_FS
39+
Raises:
40+
ValueError: if CSV format unrecognized or recording too short
41+
"""
42+
# Step 1 — Load
43+
eeg, src_fs = load_csv(filepath)
10444

45+
# Step 2 — Resample to 200Hz
46+
eeg = resample(eeg, src_fs)
10547

106-
def resample(eeg: np.ndarray, src_fs: int, tgt_fs: int = TGT_FS) -> np.ndarray:
107-
"""
108-
Resample EEG from src_fs to tgt_fs using Fourier method.
48+
# Step 3 — Validate duration
49+
duration = eeg.shape[0] / TGT_FS
50+
if duration < WIN_SEC:
51+
raise ValueError(
52+
f"Recording too short: {duration:.1f}s. Need at least {WIN_SEC}s."
53+
)
10954

110-
Args:
111-
eeg: (n_samples, 4) array
112-
src_fs: source sample rate
113-
tgt_fs: target sample rate (default 200Hz)
55+
# Step 4 — Segment into windows
56+
segments = segment(eeg)
11457

115-
Returns:
116-
resampled: (n_out, 4) array at tgt_fs
117-
"""
118-
if src_fs == tgt_fs:
119-
return eeg
120-
n_out = int(eeg.shape[0] * tgt_fs / src_fs)
121-
return np.stack(
122-
[scipy_signal.resample(eeg[:, c], n_out) for c in range(eeg.shape[1])],
123-
axis=1
124-
).astype(np.float32)
58+
return segments, src_fs, duration
12559

12660

12761
def from_array(raw: np.ndarray,
12862
src_fs: int,
129-
channel_order: Optional[list[str]] = None) -> tuple[list[np.ndarray], float]:
63+
channel_order: Optional[list[str]] = None
64+
) -> tuple[list[np.ndarray], float]:
13065
"""
131-
Preprocess a raw EEG numpy array directly — no CSV needed.
66+
Full preprocessing pipeline from a raw numpy array.
13267
Use this for real-time streaming from the Muse 2 headset via LSL.
13368
13469
Args:
135-
raw: (n_samples, 4) float32 array, channels in order
136-
[TP9, AF7, AF8, TP10] by default
70+
raw: (n_samples, 4) float32 array
71+
channels in order [TP9, AF7, AF8, TP10] by default
13772
src_fs: sample rate of the incoming data (e.g. 256 for Muse 2)
13873
channel_order: list of 4 channel names if different from default
139-
default: ['TP9', 'AF7', 'AF8', 'TP10']
14074
14175
Returns:
142-
segments: list of (4, 800) numpy arrays
76+
segments: list of (4, 800) numpy arrays, ready for encoder.py
14377
duration: recording duration in seconds
14478
14579
Example (LSL stream):
@@ -155,8 +89,7 @@ def from_array(raw: np.ndarray,
15589
f"Expected (n_samples, 4) array, got shape {raw.shape}"
15690
)
15791

158-
eeg = raw.astype(np.float32)
159-
eeg = resample(eeg, src_fs)
92+
eeg = resample(raw.astype(np.float32), src_fs)
16093

16194
duration = eeg.shape[0] / TGT_FS
16295
if duration < WIN_SEC:
@@ -168,64 +101,14 @@ def from_array(raw: np.ndarray,
168101
return segments, duration
169102

170103

171-
def segment(eeg: np.ndarray,
172-
win_samples: int = WIN_SAMPLES,
173-
step_samples: int = STEP_SAMPLES) -> list[np.ndarray]:
174-
"""
175-
Slice EEG into overlapping windows.
176-
177-
Args:
178-
eeg: (n_samples, 4) array at TGT_FS
179-
win_samples: samples per window (default 800 = 4s @ 200Hz)
180-
step_samples: step size (default 400 = 2s, 50% overlap)
181-
182-
Returns:
183-
segments: list of (4, win_samples) arrays — channels first
184-
"""
185-
segs, start = [], 0
186-
while start + win_samples <= eeg.shape[0]:
187-
seg = eeg[start:start + win_samples, :].T # (4, 800)
188-
segs.append(seg)
189-
start += step_samples
190-
return segs
191-
192-
193-
def preprocess(filepath: str) -> tuple[list[np.ndarray], int, float]:
194-
"""
195-
Full preprocessing pipeline: load → resample → segment.
196-
197-
Args:
198-
filepath: path to Muse 2 CSV file
199-
200-
Returns:
201-
segments: list of (4, 800) numpy arrays
202-
src_fs: detected source sample rate
203-
duration: recording duration in seconds
204-
205-
Raises:
206-
ValueError: if CSV format unrecognized or recording too short
207-
"""
208-
eeg, src_fs = load_csv(filepath)
209-
eeg = resample(eeg, src_fs)
210-
211-
duration = eeg.shape[0] / TGT_FS
212-
if duration < WIN_SEC:
213-
raise ValueError(
214-
f"Recording too short: {duration:.1f}s. Need at least {WIN_SEC}s."
215-
)
216-
217-
segments = segment(eeg)
218-
return segments, src_fs, duration
219-
220-
221104
if __name__ == '__main__':
222-
# Quick test
223105
import sys
224106
if len(sys.argv) < 2:
225107
print("Usage: python preprocessing.py path/to/recording.csv")
226108
sys.exit(1)
227109

228110
segs, fs, dur = preprocess(sys.argv[1])
229111
print(f"Sample rate detected: {fs} Hz")
230-
print(f"Duration: {dur:.1f}s")
231-
print(f"Segments: {len(segs)} x {segs[0].shape}")
112+
print(f"Duration: {dur:.1f}s")
113+
print(f"Segments: {len(segs)} x {segs[0].shape}")
114+
print(f"Channels: {CHANNEL_NAMES}")

0 commit comments

Comments
 (0)