@@ -59,7 +59,7 @@ class EmbedderConfig(HFModelConfig):
5959 default_prompt : str | None = Field (
6060 None , description = "Default prompt for the model. This is used when no task specific prompt is not provided."
6161 )
62- classifier_prompt : str | None = Field (None , description = "Prompt for classifier." )
62+ classification_prompt : str | None = Field (None , description = "Prompt for classifier." )
6363 cluster_prompt : str | None = Field (None , description = "Prompt for clustering." )
6464 sts_prompt : str | None = Field (None , description = "Prompt for finding most similar sentences." )
6565 query_prompt : str | None = Field (None , description = "Prompt for query." )
@@ -79,8 +79,8 @@ def get_prompt_config(self) -> dict[str, str] | None:
7979 prompts = {}
8080 if self .default_prompt :
8181 prompts [TaskTypeEnum .default .value ] = self .default_prompt
82- if self .classifier_prompt :
83- prompts [TaskTypeEnum .classification .value ] = self .classifier_prompt
82+ if self .classification_prompt :
83+ prompts [TaskTypeEnum .classification .value ] = self .classification_prompt
8484 if self .cluster_prompt :
8585 prompts [TaskTypeEnum .cluster .value ] = self .cluster_prompt
8686 if self .query_prompt :
@@ -91,7 +91,7 @@ def get_prompt_config(self) -> dict[str, str] | None:
9191 prompts [TaskTypeEnum .sts .value ] = self .sts_prompt
9292 return prompts if len (prompts ) > 0 else None
9393
94- def get_prompt_type (self , prompt_type : TaskTypeEnum | None ) -> str | None : # noqa: PLR0911
94+ def get_prompt (self , prompt_type : TaskTypeEnum | None ) -> str | None : # noqa: PLR0911
9595 """Get the prompt type for the given task type.
9696
9797 Args:
@@ -103,7 +103,7 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # no
103103 if prompt_type is None :
104104 return self .default_prompt
105105 if prompt_type == TaskTypeEnum .classification :
106- return self .classifier_prompt
106+ return self .classification_prompt
107107 if prompt_type == TaskTypeEnum .cluster :
108108 return self .cluster_prompt
109109 if prompt_type == TaskTypeEnum .query :
0 commit comments