@@ -31,7 +31,9 @@ def _build_model(self, n_channels: int, n_labels: int) -> models.Model:
3131 )
3232 return model
3333
34- def _get_model(self, n_channels: int, n_labels: int, blank: bool = False) -> models.Model:
34+ def _get_model(
35+ self, n_channels: int, n_labels: int, blank: bool = False
36+ ) -> models.Model:
3537 model_path = os.path.join(self.MODEL_DIR, "model.keras")
3638 _model = models.load_model(model_path)
3739 _model = self._build_model(n_channels, n_labels)
@@ -135,11 +137,22 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
135137 }
136138
137139 def _group_rows(self, df: pd.DataFrame, time_col: str) -> List[Dict]:
138- """ Group consecutive rows with the same predicted label ."""
139- segments = []
140- current = None
141- for _, row in df.iterrows():
142- label = row['pred_label']
140+ def _collect_samples(
141+ self, tasks: List[Dict], params: Dict, label2idx: Dict[str, int]
142+ ) -> Tuple[List, List]:
143+ def predict(
144+ self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs
145+ ) -> ModelResponse:
146+ return ModelResponse(
147+ predictions=predictions, model_version=self.get("model_version")
148+ )
149+ if (
150+ len(tasks) % self.START_TRAINING_EACH_N_UPDATES != 0
151+ and event != "START_TRAINING"
152+ ):
153+ model = self._get_model(
154+ len(params["channels"]), len(params["labels"]), blank=True
155+ )
143156 if current and current['label'] == label:
144157 current['end'] = row[time_col]
145158 current['scores'].append(row['score'])
0 commit comments