Skip to content

Commit 49a5c4a

Browse files
Add modular ML scripts: preprocessing, encoder, classifier, coordinator
1 parent 2e7d969 commit 49a5c4a

File tree

4 files changed

+872
-0
lines changed

4 files changed

+872
-0
lines changed
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
"""
2+
MOSS - classifier.py
3+
====================
4+
Handles MLP classifier training, saving, loading, and prediction.
5+
6+
Input: (N, 768) numpy embeddings from encoder.py
7+
Output: predicted labels + confidence scores
8+
9+
Each task (activity, focus, emotion, stress) has its own saved .pkl file.
10+
New tasks can be added by training a new classifier on embeddings for that task.
11+
12+
Used by: coordinator.py
13+
"""
14+
15+
import os
16+
import pickle
17+
import numpy as np
18+
from typing import Optional
19+
from sklearn.preprocessing import StandardScaler
20+
from sklearn.neural_network import MLPClassifier
21+
from sklearn.model_selection import StratifiedKFold
22+
from sklearn.metrics import accuracy_score, balanced_accuracy_score
23+
from sklearn.utils.class_weight import compute_sample_weight
24+
from collections import Counter
25+
26+
# ── Default paths ──────────────────────────────────────────────────────────────
# Models are stored in a 'moss_models' folder located beside this module.
DEFAULT_MODELS_DIR = os.path.join(
    os.path.dirname(__file__),
    'moss_models',
)

# ── Task → classifier file mapping ────────────────────────────────────────────
# Tasks missing from this table are resolved by the code that consumes it
# (MossClassifier.pkl_path falls back to '<task>_classifier.pkl').
TASK_CLASSIFIER_MAP = dict(
    activity='muse2_classifier.pkl',
    focus='focus_classifier.pkl',
    emotion='emotion_classifier.pkl',
    stress='stress_classifier.pkl',
)
36+
37+
38+
class MossClassifier:
    """
    Thin wrapper around sklearn MLP for MOSS mental state classification.

    Handles:
    - Training with optional class balancing
    - Saving/loading to .pkl
    - Predicting labels + confidence scores from embeddings

    Attributes:
        task:        task name; selects the .pkl filename (see ``pkl_path``)
        label_names: ordered list of class names (index = class id)
        models_dir:  directory where .pkl files are saved/loaded
        clf:         fitted MLPClassifier, or None until train()/load()
        scaler:      fitted StandardScaler, or None until train()/load()
    """

    def __init__(self,
                 task: str,
                 label_names: list[str],
                 models_dir: str = DEFAULT_MODELS_DIR):
        """
        Args:
            task:        task name (e.g. 'activity', 'focus', 'emotion')
            label_names: ordered list of class names (index = class id)
            models_dir:  directory where .pkl files are saved/loaded
        """
        self.task = task
        self.label_names = label_names
        self.models_dir = models_dir
        self.clf = None
        self.scaler = None
        # The directory is created eagerly so a later save() cannot fail on a
        # missing folder.
        os.makedirs(models_dir, exist_ok=True)

    @property
    def pkl_path(self) -> str:
        """Full path of this task's .pkl file inside ``models_dir``.

        Known tasks use the filename from TASK_CLASSIFIER_MAP; any other task
        name falls back to '<task>_classifier.pkl'.
        """
        filename = TASK_CLASSIFIER_MAP.get(self.task, f'{self.task}_classifier.pkl')
        return os.path.join(self.models_dir, filename)

    def _label_name(self, class_id) -> str:
        """Map an integer class id to its label name.

        Falls back to ``str(class_id)`` when the id is not an integer or is
        outside ``label_names`` (e.g. a bundle loaded without label names), so
        prediction and reporting never fail with an IndexError.
        """
        if isinstance(class_id, (int, np.integer)) and 0 <= class_id < len(self.label_names):
            return self.label_names[class_id]
        return str(class_id)

    @staticmethod
    def _fit_mlp(clf, X, y, balance_classes: bool):
        """Fit ``clf`` on (X, y), weighting samples only when requested.

        The sample-weight argument is passed by keyword and only when
        balancing is enabled: scikit-learn releases whose MLPClassifier.fit
        accepts just (X, y) reject any extra positional argument, even None.
        """
        if balance_classes:
            clf.fit(X, y, sample_weight=compute_sample_weight('balanced', y))
        else:
            clf.fit(X, y)
        return clf

    def train(self,
              embeddings: np.ndarray,
              labels: np.ndarray,
              balance_classes: bool = True,
              n_splits: int = 5) -> dict:
        """
        Train MLP classifier on embeddings with optional k-fold CV evaluation.

        Args:
            embeddings:      (N, 768) array
            labels:          (N,) integer class labels
            balance_classes: use sample weights to handle class imbalance
            n_splits:        number of CV folds (set to 0 to skip CV; any
                             value <= 1 skips it)

        Returns:
            results: dict with 'n_samples', 'n_classes', 'label_names' and
                     'class_distribution'; when CV ran it also holds
                     'cv_accuracy', 'cv_balanced_accuracy' and
                     'cv_fold_accuracies'.
        """
        # Accept plain lists as well: the fold indices below need ndarray
        # fancy indexing.
        embeddings = np.asarray(embeddings)
        labels = np.asarray(labels)

        results = {}

        # Optional cross-validation evaluation
        if n_splits > 1:
            skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
            fold_accs, fold_bals = [], []

            for tr, te in skf.split(embeddings, labels):
                # Fit the scaler on the training fold only, so the held-out
                # fold does not influence the normalisation statistics.
                scaler = StandardScaler()
                X_tr = scaler.fit_transform(embeddings[tr])
                X_te = scaler.transform(embeddings[te])

                clf = self._fit_mlp(self._make_mlp(), X_tr, labels[tr], balance_classes)

                preds = clf.predict(X_te)
                fold_accs.append(accuracy_score(labels[te], preds))
                fold_bals.append(balanced_accuracy_score(labels[te], preds))

            results['cv_accuracy'] = float(np.mean(fold_accs))
            results['cv_balanced_accuracy'] = float(np.mean(fold_bals))
            results['cv_fold_accuracies'] = [float(x) for x in fold_accs]

        # Train final classifier on all data
        self.scaler = StandardScaler()
        X_all = self.scaler.fit_transform(embeddings)
        self.clf = self._fit_mlp(self._make_mlp(), X_all, labels, balance_classes)

        results['n_samples'] = len(labels)
        results['n_classes'] = len(np.unique(labels))
        results['label_names'] = self.label_names
        results['class_distribution'] = {
            self._label_name(k): int(v)
            for k, v in sorted(Counter(labels.tolist()).items())
        }

        return results

    def _make_mlp(self) -> MLPClassifier:
        """Build an unfitted MLP with the fixed MOSS hyper-parameters."""
        return MLPClassifier(
            hidden_layer_sizes=(256, 128),
            max_iter=500,
            random_state=42,
            early_stopping=True,
            n_iter_no_change=20
        )

    def save(self) -> str:
        """Save trained classifier + scaler to .pkl. Returns path.

        Raises:
            RuntimeError: if train() or load() has not populated the model yet.
        """
        if self.clf is None or self.scaler is None:
            raise RuntimeError("Classifier not trained yet. Call train() first.")

        bundle = {
            'classifier': self.clf,
            'scaler': self.scaler,
            'label_names': self.label_names,
            'activities': self.label_names,  # kept for predict.py compatibility
            'task': self.task,
        }
        with open(self.pkl_path, 'wb') as f:
            pickle.dump(bundle, f)

        return self.pkl_path

    def load(self) -> None:
        """Load classifier + scaler from .pkl.

        Replaces ``clf``, ``scaler`` and ``label_names`` with the bundle's
        contents. Bundles without 'label_names' fall back to the legacy
        'activities' key, then to an empty list.

        Raises:
            FileNotFoundError: if no .pkl exists for this task.
        """
        if not os.path.exists(self.pkl_path):
            raise FileNotFoundError(
                f"No classifier found for task '{self.task}' at {self.pkl_path}\n"
                f"Train the classifier first using the appropriate train script."
            )
        # SECURITY: unpickling can execute arbitrary code. Only load .pkl files
        # written by save() from a trusted models directory.
        with open(self.pkl_path, 'rb') as f:
            bundle = pickle.load(f)

        self.clf = bundle['classifier']
        self.scaler = bundle['scaler']
        self.label_names = bundle.get('label_names', bundle.get('activities', []))

    def predict(self, embeddings: np.ndarray) -> tuple[list[str], np.ndarray]:
        """
        Predict mental state labels for a batch of embeddings.

        Loads the saved model on first use if nothing has been trained or
        loaded yet.

        Args:
            embeddings: (N, 768) numpy array

        Returns:
            labels:      list of N predicted label strings (the class id as a
                         string when no name is known for it)
            confidences: (N, n_classes) probability array as returned by
                         ``clf.predict_proba``
        """
        # Same readiness check as save(): both pieces are needed to predict.
        if self.clf is None or self.scaler is None:
            self.load()

        X = self.scaler.transform(embeddings)
        pred_indices = self.clf.predict(X)
        probabilities = self.clf.predict_proba(X)
        pred_labels = [self._label_name(i) for i in pred_indices]

        return pred_labels, probabilities

    def predict_majority(self, embeddings: np.ndarray) -> tuple[str, float, np.ndarray]:
        """
        Predict a single label for a recording via majority vote across segments.

        Args:
            embeddings: (N, 768) array for all segments in a recording

        Returns:
            label:      overall predicted label string
            confidence: fraction of segments that voted for this label
            mean_proba: (n_classes,) mean probability across all segments
        """
        labels, probas = self.predict(embeddings)
        # Compute the winning label and its vote count in a single pass.
        top_label, votes = Counter(labels).most_common(1)[0]
        confidence = votes / len(labels)
        mean_proba = probas.mean(axis=0)

        return top_label, confidence, mean_proba
208+
209+
210+
def load_classifier(task: str,
                    models_dir: str = DEFAULT_MODELS_DIR) -> 'MossClassifier':
    """
    Build a MossClassifier for ``task`` and restore its state from disk.

    Args:
        task:       'activity', 'focus', 'emotion', or 'stress'
        models_dir: directory containing .pkl files

    Returns:
        MossClassifier on which ``load()`` has already been called
    """
    # label_names starts out empty; load() is invoked immediately afterwards.
    restored = MossClassifier(task=task, label_names=[], models_dir=models_dir)
    restored.load()
    return restored
225+
226+
227+
if __name__ == '__main__':
    # Smoke test: restore the saved activity classifier and report its
    # task, label set and underlying estimator, one item per line.
    clf = load_classifier('activity')
    print(
        f"Task: {clf.task}",
        f"Labels: {clf.label_names}",
        f"Classifier: {clf.clf}",
        sep='\n',
    )

0 commit comments

Comments
 (0)