1+ """Logistic regression based time series segmenter."""
2+
13import os
24import io
35import pickle
46import logging
7+ from typing import List , Dict , Optional
8+
59import pandas as pd
610import numpy as np
711import label_studio_sdk
812
9- from typing import List , Dict , Optional
1013from sklearn .linear_model import LogisticRegression
1114from label_studio_ml .model import LabelStudioMLBase
1215from label_studio_ml .response import ModelResponse
1316
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Cached model instance to avoid reloading the pickle on each request.
# Shared by every segmenter instance in the process; None until first resolved.
_model: Optional[LogisticRegression] = None
1721
1822
1923class TimeSeriesSegmenter (LabelStudioMLBase ):
20- """Simple time series segmentation using logistic regression ."""
24+ """Simple logistic regression based segmenter for time series ."""
2125
2226 LABEL_STUDIO_HOST = os .getenv ('LABEL_STUDIO_HOST' , 'http://localhost:8080' )
2327 LABEL_STUDIO_API_KEY = os .getenv ('LABEL_STUDIO_API_KEY' )
2428 START_TRAINING_EACH_N_UPDATES = int (os .getenv ('START_TRAINING_EACH_N_UPDATES' , 10 ))
2529 MODEL_DIR = os .getenv ('MODEL_DIR' , '.' )
2630
def setup(self):
    """Initialize model metadata.

    Stores the model version string, built from the class name followed by
    the fixed suffix ``-v0.0.1``, under the ``model_version`` key via
    ``self.set``.
    """
    self.set("model_version", f"{self.__class__.__name__}-v0.0.1")
2934
3035 # util functions
def _get_model(self, blank: bool = False) -> LogisticRegression:
    """Return the cached model, a model loaded from disk, or a new one.

    Resolution order when ``blank`` is False: the module-level cache, then
    ``model.pkl`` inside ``MODEL_DIR`` if that file exists, then a freshly
    constructed classifier.

    Args:
        blank: When True, ignore both the cache and any pickled model on
            disk and create a new, untrained classifier. The new instance
            replaces the cached one.

    Returns:
        The model now stored in the module-level ``_model`` cache.
    """
    global _model
    if _model is not None and not blank:
        return _model
    model_path = os.path.join(self.MODEL_DIR, "model.pkl")
    if not blank and os.path.exists(model_path):
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file, so MODEL_DIR must only contain files this service wrote.
        with open(model_path, "rb") as f:
            _model = pickle.load(f)
    else:
        _model = LogisticRegression(max_iter=1000)
    return _model
4248
4349 def _get_labeling_params (self ) -> Dict :
50+ """Extract tag names and channel info from the labeling config."""
4451 from_name , to_name , value = self .label_interface .get_first_tag_occurence (
45- 'TimeSeriesLabels' , 'TimeSeries' )
52+ "TimeSeriesLabels" , "TimeSeries"
53+ )
4654 tag = self .label_interface .get_tag (from_name )
4755 labels = list (tag .labels )
4856 ts_tag = self .label_interface .get_tag (to_name )
49- time_col = ts_tag .attr .get (' timeColumn' )
50- # parse channel names from the original config since tag doesn't expose children
57+ time_col = ts_tag .attr .get (" timeColumn" )
58+ # Parse channels from the original XML because the tag does not expose its children
5159 import xml .etree .ElementTree as ET
60+
5261 root = ET .fromstring (self .label_config )
5362 ts_elem = root .find (f".//TimeSeries[@name='{ to_name } ']" )
54- channels = [ch .attrib [' column' ] for ch in ts_elem .findall (' Channel' )]
63+ channels = [ch .attrib [" column" ] for ch in ts_elem .findall (" Channel" )]
5564 return {
5665 'from_name' : from_name ,
5766 'to_name' : to_name ,
@@ -62,23 +71,30 @@ def _get_labeling_params(self) -> Dict:
6271 }
6372
def _read_csv(self, task: Dict, path: str) -> pd.DataFrame:
    """Load the CSV associated with the task into a DataFrame.

    Args:
        task: Task dictionary, forwarded unchanged to
            ``self.preload_task_data``.
        path: Value forwarded to ``self.preload_task_data`` to locate the
            CSV content for this task.

    Returns:
        The parsed CSV as a ``pandas.DataFrame``.
    """
    # preload_task_data is expected to return the CSV content as text;
    # wrap it in an in-memory buffer so pandas can parse it like a file.
    csv_str = self.preload_task_data(task, path)
    return pd.read_csv(io.StringIO(csv_str))
6777
68- def predict (self , tasks : List [Dict ], context : Optional [Dict ] = None , ** kwargs ) -> ModelResponse :
78+ def predict (
79+ self , tasks : List [Dict ], context : Optional [Dict ] = None , ** kwargs
80+ ) -> ModelResponse :
81+ """Return time series segments predicted for the given tasks."""
6982 params = self ._get_labeling_params ()
7083 model = self ._get_model ()
7184 predictions = []
7285 for task in tasks :
73- df = self ._read_csv (task , task ['data' ][params ['value' ]])
86+ df = self ._read_csv (task , task ["data" ][params ["value" ]])
87+ # each row is described by the selected channels
7488 X = df [params ['channels' ]].values
7589 if len (X ) == 0 :
7690 predictions .append ({})
7791 continue
92+ # predict probabilities for each label
7893 probs = model .predict_proba (X )
7994 labels_idx = np .argmax (probs , axis = 1 )
8095 df ['pred_label' ] = [params ['labels' ][i ] for i in labels_idx ]
8196 df ['score' ] = probs [np .arange (len (probs )), labels_idx ]
97+ # group consecutive rows with the same predicted label
8298 segments = []
8399 current = None
84100 for _ , row in df .iterrows ():
@@ -123,13 +139,21 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
123139 return ModelResponse (predictions = predictions , model_version = self .get ('model_version' ))
124140
def _get_tasks(self, project_id: int) -> List[Dict]:
    """Fetch labeled tasks from Label Studio.

    Connects with the class-level ``LABEL_STUDIO_HOST`` and
    ``LABEL_STUDIO_API_KEY`` settings.

    Args:
        project_id: Identifier of the project whose tasks are requested.

    Returns:
        Whatever ``project.get_labeled_tasks()`` returns for that project.
    """
    ls = label_studio_sdk.Client(
        self.LABEL_STUDIO_HOST, self.LABEL_STUDIO_API_KEY
    )
    project = ls.get_project(id=project_id)
    return project.get_labeled_tasks()
129148
130149 def fit (self , event , data , ** kwargs ):
131- if event not in ('ANNOTATION_CREATED' , 'ANNOTATION_UPDATED' , 'START_TRAINING' ):
132- logger .info (f"Skip training: event { event } is not supported" )
150+ """Train the model on all labeled segments."""
151+ if event not in (
152+ "ANNOTATION_CREATED" ,
153+ "ANNOTATION_UPDATED" ,
154+ "START_TRAINING" ,
155+ ):
156+ logger .info ("Skip training: event %s is not supported" , event )
133157 return
134158 project_id = data ['annotation' ]['project' ]
135159 tasks = self ._get_tasks (project_id )
@@ -139,11 +163,12 @@ def fit(self, event, data, **kwargs):
139163 return
140164 params = self ._get_labeling_params ()
141165 label2idx = {l : i for i , l in enumerate (params ['labels' ])}
142- X , y = [], []
166+ X , y = [], [] # features and labels for classifier
143167 for task in tasks :
144168 df = self ._read_csv (task , task ['data' ][params ['value' ]])
145169 if df .empty :
146170 continue
171+ # convert labeled segments into per-row samples
147172 annotations = [a for a in task ['annotations' ] if a .get ('result' )]
148173 for ann in annotations :
149174 for r in ann ['result' ]:
@@ -152,8 +177,12 @@ def fit(self, event, data, **kwargs):
152177 start = r ['value' ]['start' ]
153178 end = r ['value' ]['end' ]
154179 label = r ['value' ]['timeserieslabels' ][0 ]
155- mask = (df [params ['time_col' ]] >= start ) & (df [params ['time_col' ]] <= end )
180+ mask = (
181+ (df [params ['time_col' ]] >= start )
182+ & (df [params ['time_col' ]] <= end )
183+ )
156184 seg = df .loc [mask , params ['channels' ]].values
185+ # each row inside the labeled range belongs to the segment
157186 X .extend (seg )
158187 y .extend ([label2idx [label ]] * len (seg ))
159188 if not X :
@@ -163,9 +192,11 @@ def fit(self, event, data, **kwargs):
163192 model .fit (np .array (X ), np .array (y ))
164193 os .makedirs (self .MODEL_DIR , exist_ok = True )
165194 model_path = os .path .join (self .MODEL_DIR , 'model.pkl' )
195+ # save trained model to disk
166196 with open (model_path , 'wb' ) as f :
167197 pickle .dump (model , f )
168198 global _model
199+ # reload the cached model on next prediction
169200 _model = None
170201 self ._get_model ()
171202
0 commit comments