|
19 | 19 | InferenceNodeConfig, |
20 | 20 | LoggingConfig, |
21 | 21 | VectorIndexConfig, |
| 22 | + get_default_embedder_config, |
22 | 23 | get_default_vector_index_config, |
23 | 24 | ) |
24 | 25 | from autointent.custom_types import ListOfGenericLabels, NodeType, SearchSpacePreset, SearchSpaceValidationMode |
@@ -56,7 +57,7 @@ def __init__( |
56 | 57 |
|
57 | 58 | if isinstance(nodes[0], NodeOptimizer): |
58 | 59 | self.logging_config = LoggingConfig() |
59 | | - self.embedder_config = EmbedderConfig() |
| 60 | + self.embedder_config = get_default_embedder_config() |
60 | 61 | self.cross_encoder_config = CrossEncoderConfig() |
61 | 62 | self.data_config = DataConfig() |
62 | 63 | self.transformer_config = HFModelConfig() |
@@ -111,7 +112,7 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed |
111 | 112 | return cls(nodes=nodes, seed=seed) |
112 | 113 |
|
113 | 114 | @classmethod |
114 | | - def from_preset(cls, name: SearchSpacePreset, seed: int | None = 42) -> "Pipeline": |
| 115 | + def from_preset(cls, name: SearchSpacePreset, seed: int = 42) -> "Pipeline": |
115 | 116 | """Instantiate pipeline optimizer from a preset.""" |
116 | 117 | optimization_config = load_preset(name) |
117 | 118 | config = OptimizationConfig(seed=seed, **optimization_config) |
@@ -395,6 +396,19 @@ def _refit(self, context: Context) -> None: |
395 | 396 | decision_module.clear_cache() |
396 | 397 | decision_module.fit(scores, context.data_handler.train_labels(1), context.data_handler.tags) |
397 | 398 |
|
| 399 | + def _convert_score_to_float_list(self, score: Any) -> list[float]: # noqa: ANN401 |
| 400 | + """Convert score to list of floats for InferencePipelineUtteranceOutput.""" |
| 401 | + if hasattr(score, "tolist"): |
| 402 | + result = score.tolist() |
| 403 | + return result if isinstance(result, list) else [float(result)] |
| 404 | + if score is None: |
| 405 | + return [] |
| 406 | + if isinstance(score, int | float): |
| 407 | + return [float(score)] |
| 408 | + if hasattr(score, "__iter__") and not isinstance(score, str): |
| 409 | + return [float(x) for x in score] |
| 410 | + return [float(score)] |
| 411 | + |
398 | 412 | def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput: |
399 | 413 | """Predict the labels for the utterances with metadata. |
400 | 414 |
|
@@ -422,13 +436,13 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu |
422 | 436 | regex_prediction_metadata=regex_predictions_metadata[idx] |
423 | 437 | if regex_predictions_metadata is not None |
424 | 438 | else None, |
425 | | - score=scores[idx], |
| 439 | + score=self._convert_score_to_float_list(scores[idx]), |
426 | 440 | score_metadata=scores_metadata[idx] if scores_metadata is not None else None, |
427 | 441 | ) |
428 | 442 | outputs.append(output) |
429 | 443 |
|
430 | 444 | return InferencePipelineOutput( |
431 | | - predictions=predictions, |
| 445 | + predictions=predictions, # type: ignore[arg-type] |
432 | 446 | regex_predictions=regex_predictions, |
433 | 447 | utterances=outputs, |
434 | 448 | ) |
|
0 commit comments