Commit 80e7b2a

Fixes
1 parent beba194 commit 80e7b2a

3 files changed: +161 -80 lines changed

.rules/new_models_best_practice.mdc

Lines changed: 2 additions & 3 deletions
@@ -33,7 +33,7 @@ Each example should contain the following files:

 - Reference the main repository README to help users understand how to install and run the ML backend.
 - Include labeling configuration examples in the example README so users can quickly reproduce training and inference.
-- Provide troubleshooting tips or links to Label Studio documentation such as [Writing your own ML backend](https://labelstud.io/guide/ml_create).
+- Provide troubleshooting tips or links to Label Studio documentation such as [Writing your own ML backend](mdc:https:/labelstud.io/guide/ml_create).

 ## 3.1. Security Best Practices

@@ -106,8 +106,6 @@ def get_model(self, model_path):
 Implement proper model versioning for production systems:

 - **Version Tracking**: Include model version in predictions and logs
-- **Backwards Compatibility**: Handle multiple model versions gracefully
-- **Migration Strategies**: Provide clear upgrade paths for model updates
 - **Rollback Support**: Maintain ability to revert to previous model versions

 Example versioning pattern:
@@ -148,6 +146,7 @@ Implement robust data processing for different scenarios:
 - **Type Safety**: Use proper type conversion and validation for different data types
 - **Streaming Data**: Support large files that don't fit in memory using streaming approaches
 - **Data Caching**: Cache preprocessed data when appropriate to improve performance
+- **LabelStudioMLBackend::preload_task_data(path, task)**: Download all URLs from a task and stores them locally. It uses get_local_path() `from label_studio_sdk._extensions.label_studio_tools.core.utils.io import get_local_path` and requires LABEL_STUDIO_API_KEY and LABEL_STUDIO_URL to be able to download files through Label Studio instance.

 Example robust data loading:
 ```python

label_studio_ml/examples/timeseries_segmenter/README.md

Lines changed: 88 additions & 31 deletions
@@ -46,57 +46,114 @@ columns.

 ## Training

-Training starts automatically when annotations are created or updated. The model
-collects all labeled segments, extracts sensor values inside each segment and
-fits an LSTM classifier. Model artifacts are stored in the
-`MODEL_DIR` (defaults to the current directory).
+Training starts automatically when annotations are created or updated. The model uses a PyTorch-based LSTM neural network with proper temporal modeling to learn time series patterns.

-Steps performed by `fit()`:
+### Training Process

-1. Fetch all labeled tasks from Label Studio.
-2. Convert labeled ranges to per-row training samples.
-3. Fit a small LSTM network.
-4. Save the trained model to disk.
+The model follows these steps during training:
+
+1. **Data Collection**: Fetches all labeled tasks from your Label Studio project
+2. **Sample Generation**: Converts labeled time ranges into training samples:
+   - **Background Class**: Unlabeled time periods are treated as "background" (class 0)
+   - **Event Classes**: Your labeled segments (e.g., "Run", "Walk") become classes 1, 2, etc.
+   - **Ground Truth Priority**: If multiple annotations exist for a task, ground truth annotations take precedence
+3. **Model Training**: Fits a multi-layer LSTM network with:
+   - Configurable sequence windows (default: 50 timesteps)
+   - Dropout regularization for better generalization
+   - Background class support for realistic time series modeling
+4. **Model Persistence**: Saves trained model artifacts to `MODEL_DIR`
+
+### Training Configuration
+
+You can customize training behavior with these environment variables:
+
+- `START_TRAINING_EACH_N_UPDATES`: How often to retrain (default: 1, trains on every annotation)
+- `TRAIN_EPOCHS`: Number of training epochs (default: 1000)
+- `SEQUENCE_SIZE`: Sliding window size for temporal context (default: 50)
+- `HIDDEN_SIZE`: LSTM hidden layer size (default: 64)
+
+### Ground Truth Handling
+
+When multiple annotations exist for the same task, the model prioritizes ground truth annotations:
+- Non-ground truth annotations are processed first
+- Ground truth annotations override previous labels and stop processing for that task
+- This ensures the highest quality labels are used for training

 ## Prediction

-For each task, the backend loads the CSV, applies the trained classifier to each
-row and groups consecutive predictions into labeled segments. Prediction scores
-are averaged per segment and returned to Label Studio.
+The model processes new time series data by applying the trained LSTM classifier with sliding window temporal context. Only meaningful event segments are returned to Label Studio, filtering out background periods automatically.
+
+### Prediction Process

-The `predict()` method:
+For each task, the model performs these steps:

-1. Loads the stored model.
-2. Reads the task CSV and builds a feature matrix.
-3. Predicts a label for each row.
-4. Merges consecutive rows with the same label into a segment.
-5. Returns the segments in Label Studio JSON format.
+1. **Model Loading**: Loads the trained PyTorch model from disk
+2. **Data Processing**: Reads the task CSV and creates feature vectors from sensor channels
+3. **Temporal Prediction**: Applies LSTM with sliding windows for temporal context:
+   - Uses overlapping windows with 50% overlap for smoother predictions
+   - Averages predictions across overlapping windows
+   - Maintains temporal dependencies between timesteps
+4. **Segment Extraction**: Groups consecutive predictions into meaningful segments:
+   - **Background Filtering**: Automatically filters out background (unlabeled) periods
+   - **Event Segmentation**: Only returns segments with actual event labels
+   - **Score Calculation**: Averages prediction confidence per segment
+5. **Result Formatting**: Returns segments in Label Studio JSON format
+
+### Prediction Quality
+
+The model provides several quality indicators:
+
+- **Per-segment Confidence**: Average prediction probability for each returned segment
+- **Temporal Consistency**: Sliding window approach reduces prediction noise
+- **Background Suppression**: Only returns segments where the model is confident about specific events
+
+This approach ensures that predictions focus on actual events rather than forcing labels on every timestep.

 ## How it works

-### Training pipeline
+### Training Pipeline

 ```mermaid
 flowchart TD
-A[Webhook event] --> B{Enough tasks?}
-B -- no --> C[Skip]
-B -- yes --> D[Load labeled tasks]
-D --> E[Collect per-row samples]
-E --> F[Fit LSTM]
-F --> G[Save model]
+A[Annotation Event] --> B{Training Trigger?}
+B -- no --> C[Skip Training]
+B -- yes --> D[Fetch Labeled Tasks]
+D --> E[Process Annotations]
+E --> F{Ground Truth?}
+F -- yes --> G[Priority Processing]
+F -- no --> H[Standard Processing]
+G --> I[Generate Samples]
+H --> I
+I --> J[Background + Event Labels]
+J --> K[PyTorch LSTM Training]
+K --> L[Model Validation]
+L --> M[Save Model]
+M --> N[Cache in Memory]
 ```

-### Prediction pipeline
+### Prediction Pipeline

 ```mermaid
 flowchart TD
-T[Predict request] --> U[Load model]
-U --> V[Read task CSV]
-V --> W[Predict label per row]
-W --> X[Group consecutive labels]
-X --> Y[Return segments]
+T[Prediction Request] --> U[Load PyTorch Model]
+U --> V[Read Task CSV]
+V --> W[Extract Features]
+W --> X[Sliding Window LSTM]
+X --> Y[Overlap Averaging]
+Y --> Z[Filter Background]
+Z --> AA[Group Event Segments]
+AA --> BB[Calculate Confidence]
+BB --> CC[Return Segments]
 ```

+### Key Technical Features
+
+- **PyTorch-based LSTM**: Modern deep learning framework with better performance and flexibility
+- **Temporal Modeling**: Sliding windows capture time dependencies (default 50 timesteps)
+- **Background Class**: Realistic modeling where unlabeled periods are explicit background
+- **Ground Truth Priority**: Ensures highest quality annotations are used for training
+- **Overlap Averaging**: Smoother predictions through overlapping window consensus
+
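
As an illustration of the overlap averaging named in the list above, here is a minimal sketch. It is not the code from `model.py`: the function name `overlap_average`, the `predict_window` callable (assumed to return per-timestep class probabilities for one window), and the extra tail window are all assumptions made for this example.

```python
from typing import Callable

import numpy as np


def overlap_average(
    X: np.ndarray,
    predict_window: Callable[[np.ndarray], np.ndarray],
    num_classes: int,
    window: int = 50,
) -> np.ndarray:
    """Average per-timestep class probabilities over 50%-overlapping windows.

    `predict_window` is a hypothetical stand-in for the trained model: it takes
    an array of shape (window, features) and returns (window, num_classes).
    """
    n = len(X)
    if n == 0:
        return np.zeros((0, num_classes))

    window = min(window, n)        # short series: use one window
    stride = max(1, window // 2)   # 50% overlap between consecutive windows

    starts = list(range(0, n - window + 1, stride))
    if starts[-1] != n - window:
        starts.append(n - window)  # assumption: add a final window so the tail is covered

    prob_sum = np.zeros((n, num_classes))
    counts = np.zeros((n, 1))
    for s in starts:
        prob_sum[s:s + window] += predict_window(X[s:s + window])
        counts[s:s + window] += 1

    # Every timestep is covered by at least one window, so counts > 0.
    return prob_sum / counts


if __name__ == "__main__":
    # Toy predictor that always returns the same two-class distribution.
    toy = lambda w: np.tile([0.2, 0.8], (len(w), 1))
    probs = overlap_average(np.zeros((120, 3)), toy, num_classes=2)
    print(probs.shape)
```

Averaging probabilities rather than hard labels keeps the per-segment confidence meaningful, since each timestep's score reflects every window that saw it.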
 ## Customize

 Edit `docker-compose.yml` to set environment variables such as `LABEL_STUDIO_HOST`
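
The training variables documented in this README change (`START_TRAINING_EACH_N_UPDATES`, `TRAIN_EPOCHS`, `SEQUENCE_SIZE`, `HIDDEN_SIZE`) could be read with their stated defaults roughly as follows. This is a sketch only: the `TrainingConfig` class and its `from_env` helper are hypothetical names, and the example's real parsing code is not shown in this diff.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class TrainingConfig:
    """Hypothetical container for the environment variables listed in the README."""

    start_training_each_n_updates: int = 1  # retrain after every N annotation updates
    train_epochs: int = 1000
    sequence_size: int = 50                 # sliding window length in timesteps
    hidden_size: int = 64                   # LSTM hidden layer size

    @classmethod
    def from_env(cls, env: Optional[Mapping[str, str]] = None) -> "TrainingConfig":
        # Fall back to the process environment when no mapping is supplied.
        env = os.environ if env is None else env
        return cls(
            start_training_each_n_updates=int(env.get("START_TRAINING_EACH_N_UPDATES", 1)),
            train_epochs=int(env.get("TRAIN_EPOCHS", 1000)),
            sequence_size=int(env.get("SEQUENCE_SIZE", 50)),
            hidden_size=int(env.get("HIDDEN_SIZE", 64)),
        )


if __name__ == "__main__":
    # Override one value; the others keep the defaults stated in the README.
    print(TrainingConfig.from_env({"TRAIN_EPOCHS": "10"}))
```

Passing a plain mapping instead of reading `os.environ` directly makes the defaults easy to check in a test without touching the real environment.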

label_studio_ml/examples/timeseries_segmenter/model.py

Lines changed: 71 additions & 46 deletions
@@ -236,6 +236,75 @@ def _group_rows(self, df: pd.DataFrame, time_col: str) -> List[Dict]:
         logger.debug(f"Grouped into {len(segments)} segments")
         return segments

+    def _process_task_annotations(
+        self, task: Dict, df: pd.DataFrame, params: Dict, label2idx: Dict[str, int]
+    ) -> Tuple[np.ndarray, int]:
+        """Process annotations for a single task and return row labels.
+
+        Args:
+            task: Label Studio task dictionary
+            df: DataFrame with time series data
+            params: Labeling parameters from label config
+            label2idx: Mapping from label names to indices
+
+        Returns:
+            Tuple of (row_labels array, number of labeled rows)
+        """
+        task_id = task.get("id", "unknown")
+
+        # Initialize all rows as background (index 0)
+        row_labels = np.zeros(len(df), dtype=np.int64) # 0 = background
+
+        annotations = [a for a in task["annotations"] if a.get("result")]
+        logger.debug(f"Task {task_id}: Found {len(annotations)} annotations")
+
+        # Mark labeled regions
+        labeled_rows = 0
+        for ann in annotations:
+            for r in ann["result"]:
+                if r["from_name"] != params["from_name"]:
+                    continue
+                start = r["value"]["start"]
+                end = r["value"]["end"]
+                label = r["value"]["timeserieslabels"][0]
+
+                # Convert start/end to same type as time column for comparison
+                time_dtype = df[params["time_col"]].dtype
+                logger.debug(f"Task {task_id}: Converting time range [{start}, {end}] to match column dtype {time_dtype}")
+                try:
+                    if 'int' in str(time_dtype):
+                        start = int(float(start))
+                        end = int(float(end))
+                    elif 'float' in str(time_dtype):
+                        start = float(start)
+                        end = float(end)
+                    # For string/datetime, keep as is
+                    logger.debug(f"Task {task_id}: Converted to [{start}, {end}]")
+                except (ValueError, TypeError) as e:
+                    logger.warning(f"Could not convert start={start}, end={end} to {time_dtype}: {e}, using original values")
+
+                # Find rows in this time range
+                try:
+                    mask = (df[params["time_col"]] >= start) & (
+                        df[params["time_col"]] <= end
+                    )
+                except TypeError as e:
+                    logger.error(f"Task {task_id}: Type error comparing times - start={start} ({type(start)}), end={end} ({type(end)}), time_col dtype={time_dtype}: {e}")
+                    # Skip this annotation if we can't compare
+                    continue
+
+                # Set the appropriate label index
+                label_idx = label2idx[label]
+                row_labels[mask] = label_idx
+                labeled_rows += mask.sum()
+                logger.debug(f"Task {task_id}: Labeled {mask.sum()} rows with '{label}' (index {label_idx})")
+
+            if ann.get('ground_truth', False):
+                logger.info(f"Task {task_id}: Ground truth annotation found: {ann['ground_truth']}")
+                break
+
+        return row_labels, labeled_rows
+
     def _collect_samples(
         self, tasks: List[Dict], params: Dict, label2idx: Dict[str, int]
     ) -> Tuple[np.ndarray, np.ndarray]:
@@ -253,52 +322,8 @@ def _collect_samples(
                 logger.warning(f"Task {task_id}: Empty dataframe, skipping")
                 continue

-            # Initialize all rows as background (index 0)
-            row_labels = np.zeros(len(df), dtype=np.int64) # 0 = background
-
-            annotations = [a for a in task["annotations"] if a.get("result")]
-            logger.debug(f"Task {task_id}: Found {len(annotations)} annotations")
-
-            # Mark labeled regions
-            labeled_rows = 0
-            for ann in annotations:
-                for r in ann["result"]:
-                    if r["from_name"] != params["from_name"]:
-                        continue
-                    start = r["value"]["start"]
-                    end = r["value"]["end"]
-                    label = r["value"]["timeserieslabels"][0]
-
-                    # Convert start/end to same type as time column for comparison
-                    time_dtype = df[params["time_col"]].dtype
-                    logger.debug(f"Task {task_id}: Converting time range [{start}, {end}] to match column dtype {time_dtype}")
-                    try:
-                        if 'int' in str(time_dtype):
-                            start = int(float(start))
-                            end = int(float(end))
-                        elif 'float' in str(time_dtype):
-                            start = float(start)
-                            end = float(end)
-                        # For string/datetime, keep as is
-                        logger.debug(f"Task {task_id}: Converted to [{start}, {end}]")
-                    except (ValueError, TypeError) as e:
-                        logger.warning(f"Could not convert start={start}, end={end} to {time_dtype}: {e}, using original values")
-
-                    # Find rows in this time range
-                    try:
-                        mask = (df[params["time_col"]] >= start) & (
-                            df[params["time_col"]] <= end
-                        )
-                    except TypeError as e:
-                        logger.error(f"Task {task_id}: Type error comparing times - start={start} ({type(start)}), end={end} ({type(end)}), time_col dtype={time_dtype}: {e}")
-                        # Skip this annotation if we can't compare
-                        continue
-
-                    # Set the appropriate label index
-                    label_idx = label2idx[label]
-                    row_labels[mask] = label_idx
-                    labeled_rows += mask.sum()
-                    logger.debug(f"Task {task_id}: Labeled {mask.sum()} rows with '{label}' (index {label_idx})")
+            # Process annotations for this task
+            row_labels, labeled_rows = self._process_task_annotations(task, df, params, label2idx)

             # Add ALL rows to training data
             X_list.append(df[params["channels"]].values.astype(np.float32))
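
The labeling loop that this commit extracts into `_process_task_annotations` can be illustrated with a standalone sketch. It follows the loop as shown in the diff: annotations are visited in list order, each region overwrites the labels of the rows it covers, and the loop stops after the first annotation flagged `ground_truth`. The function name `label_rows` and the toy data are assumptions for this example, and the dtype conversion and logging from the diff are omitted for brevity.

```python
from typing import Dict, List

import numpy as np
import pandas as pd


def label_rows(
    df: pd.DataFrame,
    annotations: List[Dict],
    time_col: str,
    label2idx: Dict[str, int],
) -> np.ndarray:
    """Return one class index per row; 0 means background (unlabeled)."""
    row_labels = np.zeros(len(df), dtype=np.int64)

    for ann in annotations:
        for region in ann.get("result", []):
            value = region["value"]
            # Rows whose timestamp falls inside the labeled range (inclusive).
            mask = (df[time_col] >= value["start"]) & (df[time_col] <= value["end"])
            row_labels[mask.to_numpy()] = label2idx[value["timeserieslabels"][0]]

        # Stop after a ground truth annotation: later annotations are ignored.
        if ann.get("ground_truth", False):
            break

    return row_labels


if __name__ == "__main__":
    df = pd.DataFrame({"time": range(10)})
    annotations = [
        {"result": [{"value": {"start": 2, "end": 4, "timeserieslabels": ["Run"]}}]},
        {
            "ground_truth": True,
            "result": [{"value": {"start": 4, "end": 6, "timeserieslabels": ["Walk"]}}],
        },
        # Never reached, because the previous annotation is ground truth.
        {"result": [{"value": {"start": 8, "end": 9, "timeserieslabels": ["Run"]}}]},
    ]
    print(label_rows(df, annotations, "time", {"Run": 1, "Walk": 2}))
```

Note that labels written by annotations processed before the ground truth one are only replaced where the ground truth regions overlap them; rows outside those regions keep their earlier labels.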
