22from typing import Any , Literal
33
44from pydantic import BaseModel , ConfigDict , Field , PositiveInt
5- from typing_extensions import Self , assert_never
5+ from typing_extensions import Self
66
77from autointent .custom_types import FloatFromZeroToOne
88from autointent .metrics import SCORING_METRICS_MULTICLASS , SCORING_METRICS_MULTILABEL
@@ -91,7 +91,7 @@ def get_prompt_config(self) -> dict[str, str] | None:
9191 prompts [TaskTypeEnum .sts .value ] = self .sts_prompt
9292 return prompts if len (prompts ) > 0 else None
9393
94- def get_prompt (self , prompt_type : TaskTypeEnum | None ) -> str | None : # noqa: PLR0911
94+ def get_prompt (self , prompt_type : TaskTypeEnum | None ) -> str | None :
9595 """Get the prompt type for the given task type.
9696
9797 Args:
@@ -100,21 +100,17 @@ def get_prompt(self, prompt_type: TaskTypeEnum | None) -> str | None: # noqa: P
100100 Returns:
101101 The prompt for the given task type.
102102 """
103- if prompt_type is None :
104- return self .default_prompt
105- if prompt_type == TaskTypeEnum .classification :
103+ if prompt_type == TaskTypeEnum .classification and self .classification_prompt is not None :
106104 return self .classification_prompt
107- if prompt_type == TaskTypeEnum .cluster :
105+ if prompt_type == TaskTypeEnum .cluster and self . classification_prompt is not None :
108106 return self .cluster_prompt
109- if prompt_type == TaskTypeEnum .query :
107+ if prompt_type == TaskTypeEnum .query and self . query_prompt is not None :
110108 return self .query_prompt
111- if prompt_type == TaskTypeEnum .passage :
109+ if prompt_type == TaskTypeEnum .passage and self . passage_prompt is not None :
112110 return self .passage_prompt
113- if prompt_type == TaskTypeEnum .sts :
111+ if prompt_type == TaskTypeEnum .sts and self . sts_prompt is not None :
114112 return self .sts_prompt
115- if prompt_type == TaskTypeEnum .default :
116- return self .default_prompt
117- assert_never (prompt_type )
113+ return self .default_prompt
118114
119115
120116class CrossEncoderConfig (HFModelConfig ):
0 commit comments