Skip to content

Commit f20408f

Browse files
committed
Format timeseries segmenter model
1 parent 651cabb commit f20408f

File tree

1 file changed

+19
-6
lines changed
  • label_studio_ml/examples/timeseries_segmenter

1 file changed

+19
-6
lines changed

label_studio_ml/examples/timeseries_segmenter/model.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ def _build_model(self, n_channels: int, n_labels: int) -> models.Model:
3131
)
3232
return model
3333
34-
def _get_model(self, n_channels: int, n_labels: int, blank: bool = False) -> models.Model:
34+
def _get_model(
35+
self, n_channels: int, n_labels: int, blank: bool = False
36+
) -> models.Model:
3537
model_path = os.path.join(self.MODEL_DIR, "model.keras")
3638
_model = models.load_model(model_path)
3739
_model = self._build_model(n_channels, n_labels)
@@ -135,11 +137,22 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
135137
}
136138
137139
def _group_rows(self, df: pd.DataFrame, time_col: str) -> List[Dict]:
138-
"""Group consecutive rows with the same predicted label."""
139-
segments = []
140-
current = None
141-
for _, row in df.iterrows():
142-
label = row['pred_label']
140+
def _collect_samples(
141+
self, tasks: List[Dict], params: Dict, label2idx: Dict[str, int]
142+
) -> Tuple[List, List]:
143+
def predict(
144+
self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs
145+
) -> ModelResponse:
146+
return ModelResponse(
147+
predictions=predictions, model_version=self.get("model_version")
148+
)
149+
if (
150+
len(tasks) % self.START_TRAINING_EACH_N_UPDATES != 0
151+
and event != "START_TRAINING"
152+
):
153+
model = self._get_model(
154+
len(params["channels"]), len(params["labels"]), blank=True
155+
)
143156
if current and current['label'] == label:
144157
current['end'] = row[time_col]
145158
current['scores'].append(row['score'])

0 commit comments

Comments
 (0)