|
| 1 | +""" |
| 2 | +MOSS - classifier.py |
| 3 | +==================== |
| 4 | +Handles MLP classifier training, saving, loading, and prediction. |
| 5 | +
|
| 6 | +Input: (N, 768) numpy embeddings from encoder.py |
| 7 | +Output: predicted labels + confidence scores |
| 8 | +
|
| 9 | +Each task (activity, focus, emotion, stress) has its own saved .pkl file. |
| 10 | +New tasks can be added by training a new classifier on embeddings for that task. |
| 11 | +
|
| 12 | +Used by: coordinator.py |
| 13 | +""" |
| 14 | + |
| 15 | +import os |
| 16 | +import pickle |
| 17 | +import numpy as np |
| 18 | +from typing import Optional |
| 19 | +from sklearn.preprocessing import StandardScaler |
| 20 | +from sklearn.neural_network import MLPClassifier |
| 21 | +from sklearn.model_selection import StratifiedKFold |
| 22 | +from sklearn.metrics import accuracy_score, balanced_accuracy_score |
| 23 | +from sklearn.utils.class_weight import compute_sample_weight |
| 24 | +from collections import Counter |
| 25 | + |
| 26 | +# ── Default paths ────────────────────────────────────────────────────────────── |
| 27 | +DEFAULT_MODELS_DIR = os.path.join(os.path.dirname(__file__), 'moss_models') |
| 28 | + |
| 29 | +# ── Task → classifier file mapping ──────────────────────────────────────────── |
| 30 | +TASK_CLASSIFIER_MAP = { |
| 31 | + 'activity': 'muse2_classifier.pkl', |
| 32 | + 'focus': 'focus_classifier.pkl', |
| 33 | + 'emotion': 'emotion_classifier.pkl', |
| 34 | + 'stress': 'stress_classifier.pkl', |
| 35 | +} |
| 36 | + |
| 37 | + |
| 38 | +class MossClassifier: |
| 39 | + """ |
| 40 | + Thin wrapper around sklearn MLP for MOSS mental state classification. |
| 41 | +
|
| 42 | + Handles: |
| 43 | + - Training with optional class balancing |
| 44 | + - Saving/loading to .pkl |
| 45 | + - Predicting labels + confidence scores from embeddings |
| 46 | + """ |
| 47 | + |
| 48 | + def __init__(self, |
| 49 | + task: str, |
| 50 | + label_names: list[str], |
| 51 | + models_dir: str = DEFAULT_MODELS_DIR): |
| 52 | + """ |
| 53 | + Args: |
| 54 | + task: task name (e.g. 'activity', 'focus', 'emotion') |
| 55 | + label_names: ordered list of class names (index = class id) |
| 56 | + models_dir: directory where .pkl files are saved/loaded |
| 57 | + """ |
| 58 | + self.task = task |
| 59 | + self.label_names = label_names |
| 60 | + self.models_dir = models_dir |
| 61 | + self.clf = None |
| 62 | + self.scaler = None |
| 63 | + os.makedirs(models_dir, exist_ok=True) |
| 64 | + |
| 65 | + @property |
| 66 | + def pkl_path(self) -> str: |
| 67 | + filename = TASK_CLASSIFIER_MAP.get(self.task, f'{self.task}_classifier.pkl') |
| 68 | + return os.path.join(self.models_dir, filename) |
| 69 | + |
| 70 | + def train(self, |
| 71 | + embeddings: np.ndarray, |
| 72 | + labels: np.ndarray, |
| 73 | + balance_classes: bool = True, |
| 74 | + n_splits: int = 5) -> dict: |
| 75 | + """ |
| 76 | + Train MLP classifier on embeddings with optional k-fold CV evaluation. |
| 77 | +
|
| 78 | + Args: |
| 79 | + embeddings: (N, 768) array |
| 80 | + labels: (N,) integer class labels |
| 81 | + balance_classes: use sample weights to handle class imbalance |
| 82 | + n_splits: number of CV folds (set to 0 to skip CV) |
| 83 | +
|
| 84 | + Returns: |
| 85 | + results: dict with accuracy, balanced_accuracy, per-fold scores |
| 86 | + """ |
| 87 | + results = {} |
| 88 | + |
| 89 | + # Optional cross-validation evaluation |
| 90 | + if n_splits > 1: |
| 91 | + skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) |
| 92 | + fold_accs, fold_bals = [], [] |
| 93 | + |
| 94 | + for tr, te in skf.split(embeddings, labels): |
| 95 | + scaler = StandardScaler() |
| 96 | + X_tr = scaler.fit_transform(embeddings[tr]) |
| 97 | + X_te = scaler.transform(embeddings[te]) |
| 98 | + |
| 99 | + clf = self._make_mlp() |
| 100 | + sw = compute_sample_weight('balanced', labels[tr]) if balance_classes else None |
| 101 | + clf.fit(X_tr, labels[tr], sw) |
| 102 | + |
| 103 | + preds = clf.predict(X_te) |
| 104 | + fold_accs.append(accuracy_score(labels[te], preds)) |
| 105 | + fold_bals.append(balanced_accuracy_score(labels[te], preds)) |
| 106 | + |
| 107 | + results['cv_accuracy'] = float(np.mean(fold_accs)) |
| 108 | + results['cv_balanced_accuracy'] = float(np.mean(fold_bals)) |
| 109 | + results['cv_fold_accuracies'] = [float(x) for x in fold_accs] |
| 110 | + |
| 111 | + # Train final classifier on all data |
| 112 | + self.scaler = StandardScaler() |
| 113 | + X_all = self.scaler.fit_transform(embeddings) |
| 114 | + self.clf = self._make_mlp() |
| 115 | + sw = compute_sample_weight('balanced', labels) if balance_classes else None |
| 116 | + self.clf.fit(X_all, labels, sw) |
| 117 | + |
| 118 | + results['n_samples'] = len(labels) |
| 119 | + results['n_classes'] = len(np.unique(labels)) |
| 120 | + results['label_names'] = self.label_names |
| 121 | + results['class_distribution'] = { |
| 122 | + self.label_names[k]: int(v) |
| 123 | + for k, v in sorted(Counter(labels).items()) |
| 124 | + } |
| 125 | + |
| 126 | + return results |
| 127 | + |
| 128 | + def _make_mlp(self) -> MLPClassifier: |
| 129 | + return MLPClassifier( |
| 130 | + hidden_layer_sizes=(256, 128), |
| 131 | + max_iter=500, |
| 132 | + random_state=42, |
| 133 | + early_stopping=True, |
| 134 | + n_iter_no_change=20 |
| 135 | + ) |
| 136 | + |
| 137 | + def save(self) -> str: |
| 138 | + """Save trained classifier + scaler to .pkl. Returns path.""" |
| 139 | + if self.clf is None or self.scaler is None: |
| 140 | + raise RuntimeError("Classifier not trained yet. Call train() first.") |
| 141 | + |
| 142 | + bundle = { |
| 143 | + 'classifier': self.clf, |
| 144 | + 'scaler': self.scaler, |
| 145 | + 'label_names': self.label_names, |
| 146 | + 'activities': self.label_names, # kept for predict.py compatibility |
| 147 | + 'task': self.task, |
| 148 | + } |
| 149 | + with open(self.pkl_path, 'wb') as f: |
| 150 | + pickle.dump(bundle, f) |
| 151 | + |
| 152 | + return self.pkl_path |
| 153 | + |
| 154 | + def load(self) -> None: |
| 155 | + """Load classifier + scaler from .pkl.""" |
| 156 | + if not os.path.exists(self.pkl_path): |
| 157 | + raise FileNotFoundError( |
| 158 | + f"No classifier found for task '{self.task}' at {self.pkl_path}\n" |
| 159 | + f"Train the classifier first using the appropriate train script." |
| 160 | + ) |
| 161 | + with open(self.pkl_path, 'rb') as f: |
| 162 | + bundle = pickle.load(f) |
| 163 | + |
| 164 | + self.clf = bundle['classifier'] |
| 165 | + self.scaler = bundle['scaler'] |
| 166 | + self.label_names = bundle.get('label_names', bundle.get('activities', [])) |
| 167 | + |
| 168 | + def predict(self, embeddings: np.ndarray) -> tuple[list[str], np.ndarray]: |
| 169 | + """ |
| 170 | + Predict mental state labels for a batch of embeddings. |
| 171 | +
|
| 172 | + Args: |
| 173 | + embeddings: (N, 768) numpy array |
| 174 | +
|
| 175 | + Returns: |
| 176 | + labels: list of N predicted label strings |
| 177 | + confidences: (N, n_classes) probability array |
| 178 | + """ |
| 179 | + if self.clf is None: |
| 180 | + self.load() |
| 181 | + |
| 182 | + X = self.scaler.transform(embeddings) |
| 183 | + pred_indices = self.clf.predict(X) |
| 184 | + probabilities = self.clf.predict_proba(X) |
| 185 | + pred_labels = [self.label_names[i] for i in pred_indices] |
| 186 | + |
| 187 | + return pred_labels, probabilities |
| 188 | + |
| 189 | + def predict_majority(self, embeddings: np.ndarray) -> tuple[str, float, np.ndarray]: |
| 190 | + """ |
| 191 | + Predict a single label for a recording via majority vote across segments. |
| 192 | +
|
| 193 | + Args: |
| 194 | + embeddings: (N, 768) array for all segments in a recording |
| 195 | +
|
| 196 | + Returns: |
| 197 | + label: overall predicted label string |
| 198 | + confidence: fraction of segments that voted for this label |
| 199 | + mean_proba: (n_classes,) mean probability across all segments |
| 200 | + """ |
| 201 | + labels, probas = self.predict(embeddings) |
| 202 | + counts = Counter(labels) |
| 203 | + top_label = counts.most_common(1)[0][0] |
| 204 | + confidence = counts.most_common(1)[0][1] / len(labels) |
| 205 | + mean_proba = probas.mean(axis=0) |
| 206 | + |
| 207 | + return top_label, confidence, mean_proba |
| 208 | + |
| 209 | + |
| 210 | +def load_classifier(task: str, |
| 211 | + models_dir: str = DEFAULT_MODELS_DIR) -> 'MossClassifier': |
| 212 | + """ |
| 213 | + Convenience function to load a saved classifier by task name. |
| 214 | +
|
| 215 | + Args: |
| 216 | + task: 'activity', 'focus', 'emotion', or 'stress' |
| 217 | + models_dir: directory containing .pkl files |
| 218 | +
|
| 219 | + Returns: |
| 220 | + loaded MossClassifier ready for prediction |
| 221 | + """ |
| 222 | + clf = MossClassifier(task=task, label_names=[], models_dir=models_dir) |
| 223 | + clf.load() |
| 224 | + return clf |
| 225 | + |
| 226 | + |
| 227 | +if __name__ == '__main__': |
| 228 | + # Quick test — load activity classifier and print info |
| 229 | + clf = load_classifier('activity') |
| 230 | + print(f"Task: {clf.task}") |
| 231 | + print(f"Labels: {clf.label_names}") |
| 232 | + print(f"Classifier: {clf.clf}") |
0 commit comments