Skip to content

Commit 0567de6

Browse files
committed
Improve timeseries segmenter example
1 parent 38472fb commit 0567de6

File tree

2 files changed: +58 -19 lines changed

label_studio_ml/examples/timeseries_segmenter/README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ for new tasks. The backend expects the labeling configuration to use
1818
docker-compose up --build
1919
```
2020

21+
A small example CSV is available in `tests/time_series.csv`.
22+
2123
Connect the model from the **Model** page in your project settings. The default
2224
URL is `http://localhost:9090`.
2325

@@ -38,9 +40,9 @@ Use a configuration similar to the following:
3840
</View>
3941
```
4042

41-
The backend reads the time column and channels to build feature vectors for
42-
training and prediction. Each CSV referenced by `csv_url` is expected to contain
43-
at least the time column and the listed channels.
43+
The backend reads the time column and channels to build feature vectors. Each
44+
CSV referenced by `csv_url` must contain the time column and every column
45+
named by a `Channel` element in the labeling configuration.
4446

4547
## Training
4648

@@ -54,3 +56,9 @@ fits a logistic regression classifier. Model artifacts are stored in the
5456
For each task, the backend loads the CSV, applies the trained classifier to each
5557
row and groups consecutive predictions into labeled segments. Prediction scores
5658
are averaged per segment and returned to Label Studio.
59+
60+
## Customize
61+
62+
Edit `docker-compose.yml` to set environment variables such as `LABEL_STUDIO_HOST`,
63+
`LABEL_STUDIO_API_KEY` (used to fetch labeled tasks for training) or `MODEL_DIR`. You can also adjust
64+
`START_TRAINING_EACH_N_UPDATES` to control how often training runs.

label_studio_ml/examples/timeseries_segmenter/model.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,66 @@
1+
"""Logistic regression based time series segmenter."""
2+
13
import os
24
import io
35
import pickle
46
import logging
7+
from typing import List, Dict, Optional
8+
59
import pandas as pd
610
import numpy as np
711
import label_studio_sdk
812

9-
from typing import List, Dict, Optional
1013
from sklearn.linear_model import LogisticRegression
1114
from label_studio_ml.model import LabelStudioMLBase
1215
from label_studio_ml.response import ModelResponse
1316

1417
logger = logging.getLogger(__name__)
1518

19+
# Cached model instance to avoid reloading the pickle on each request.
1620
_model: Optional[LogisticRegression] = None
1721

1822

1923
class TimeSeriesSegmenter(LabelStudioMLBase):
20-
"""Simple time series segmentation using logistic regression."""
24+
"""Simple logistic regression based segmenter for time series."""
2125

2226
LABEL_STUDIO_HOST = os.getenv('LABEL_STUDIO_HOST', 'http://localhost:8080')
2327
LABEL_STUDIO_API_KEY = os.getenv('LABEL_STUDIO_API_KEY')
2428
START_TRAINING_EACH_N_UPDATES = int(os.getenv('START_TRAINING_EACH_N_UPDATES', 10))
2529
MODEL_DIR = os.getenv('MODEL_DIR', '.')
2630

2731
def setup(self):
28-
self.set("model_version", f'{self.__class__.__name__}-v0.0.1')
32+
"""Initialize model metadata."""
33+
self.set("model_version", f"{self.__class__.__name__}-v0.0.1")
2934

3035
# util functions
3136
def _get_model(self, blank: bool = False) -> LogisticRegression:
37+
"""Return a trained model or create a new one."""
3238
global _model
3339
if _model is not None and not blank:
3440
return _model
35-
model_path = os.path.join(self.MODEL_DIR, 'model.pkl')
41+
model_path = os.path.join(self.MODEL_DIR, "model.pkl")
3642
if not blank and os.path.exists(model_path):
37-
with open(model_path, 'rb') as f:
43+
with open(model_path, "rb") as f:
3844
_model = pickle.load(f)
3945
else:
4046
_model = LogisticRegression(max_iter=1000)
4147
return _model
4248

4349
def _get_labeling_params(self) -> Dict:
50+
"""Extract tag names and channel info from the labeling config."""
4451
from_name, to_name, value = self.label_interface.get_first_tag_occurence(
45-
'TimeSeriesLabels', 'TimeSeries')
52+
"TimeSeriesLabels", "TimeSeries"
53+
)
4654
tag = self.label_interface.get_tag(from_name)
4755
labels = list(tag.labels)
4856
ts_tag = self.label_interface.get_tag(to_name)
49-
time_col = ts_tag.attr.get('timeColumn')
50-
# parse channel names from the original config since tag doesn't expose children
57+
time_col = ts_tag.attr.get("timeColumn")
58+
# Parse channels from the original XML because the tag does not expose its children
5159
import xml.etree.ElementTree as ET
60+
5261
root = ET.fromstring(self.label_config)
5362
ts_elem = root.find(f".//TimeSeries[@name='{to_name}']")
54-
channels = [ch.attrib['column'] for ch in ts_elem.findall('Channel')]
63+
channels = [ch.attrib["column"] for ch in ts_elem.findall("Channel")]
5564
return {
5665
'from_name': from_name,
5766
'to_name': to_name,
@@ -62,23 +71,30 @@ def _get_labeling_params(self) -> Dict:
6271
}
6372

6473
def _read_csv(self, task: Dict, path: str) -> pd.DataFrame:
74+
"""Load CSV associated with the task from Label Studio."""
6575
csv_str = self.preload_task_data(task, path)
6676
return pd.read_csv(io.StringIO(csv_str))
6777

68-
def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
78+
def predict(
79+
self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs
80+
) -> ModelResponse:
81+
"""Return time series segments predicted for the given tasks."""
6982
params = self._get_labeling_params()
7083
model = self._get_model()
7184
predictions = []
7285
for task in tasks:
73-
df = self._read_csv(task, task['data'][params['value']])
86+
df = self._read_csv(task, task["data"][params["value"]])
87+
# each row is described by the selected channels
7488
X = df[params['channels']].values
7589
if len(X) == 0:
7690
predictions.append({})
7791
continue
92+
# predict probabilities for each label
7893
probs = model.predict_proba(X)
7994
labels_idx = np.argmax(probs, axis=1)
8095
df['pred_label'] = [params['labels'][i] for i in labels_idx]
8196
df['score'] = probs[np.arange(len(probs)), labels_idx]
97+
# group consecutive rows with the same predicted label
8298
segments = []
8399
current = None
84100
for _, row in df.iterrows():
@@ -123,13 +139,21 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
123139
return ModelResponse(predictions=predictions, model_version=self.get('model_version'))
124140

125141
def _get_tasks(self, project_id: int) -> List[Dict]:
126-
ls = label_studio_sdk.Client(self.LABEL_STUDIO_HOST, self.LABEL_STUDIO_API_KEY)
142+
"""Fetch labeled tasks from Label Studio."""
143+
ls = label_studio_sdk.Client(
144+
self.LABEL_STUDIO_HOST, self.LABEL_STUDIO_API_KEY
145+
)
127146
project = ls.get_project(id=project_id)
128147
return project.get_labeled_tasks()
129148

130149
def fit(self, event, data, **kwargs):
131-
if event not in ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'):
132-
logger.info(f"Skip training: event {event} is not supported")
150+
"""Train the model on all labeled segments."""
151+
if event not in (
152+
"ANNOTATION_CREATED",
153+
"ANNOTATION_UPDATED",
154+
"START_TRAINING",
155+
):
156+
logger.info("Skip training: event %s is not supported", event)
133157
return
134158
project_id = data['annotation']['project']
135159
tasks = self._get_tasks(project_id)
@@ -139,11 +163,12 @@ def fit(self, event, data, **kwargs):
139163
return
140164
params = self._get_labeling_params()
141165
label2idx = {l: i for i, l in enumerate(params['labels'])}
142-
X, y = [], []
166+
X, y = [], [] # features and labels for classifier
143167
for task in tasks:
144168
df = self._read_csv(task, task['data'][params['value']])
145169
if df.empty:
146170
continue
171+
# convert labeled segments into per-row samples
147172
annotations = [a for a in task['annotations'] if a.get('result')]
148173
for ann in annotations:
149174
for r in ann['result']:
@@ -152,8 +177,12 @@ def fit(self, event, data, **kwargs):
152177
start = r['value']['start']
153178
end = r['value']['end']
154179
label = r['value']['timeserieslabels'][0]
155-
mask = (df[params['time_col']] >= start) & (df[params['time_col']] <= end)
180+
mask = (
181+
(df[params['time_col']] >= start)
182+
& (df[params['time_col']] <= end)
183+
)
156184
seg = df.loc[mask, params['channels']].values
185+
# each row inside the labeled range belongs to the segment
157186
X.extend(seg)
158187
y.extend([label2idx[label]] * len(seg))
159188
if not X:
@@ -163,9 +192,11 @@ def fit(self, event, data, **kwargs):
163192
model.fit(np.array(X), np.array(y))
164193
os.makedirs(self.MODEL_DIR, exist_ok=True)
165194
model_path = os.path.join(self.MODEL_DIR, 'model.pkl')
195+
# save trained model to disk
166196
with open(model_path, 'wb') as f:
167197
pickle.dump(model, f)
168198
global _model
199+
# reload the cached model on next prediction
169200
_model = None
170201
self._get_model()
171202

0 commit comments

Comments
 (0)