Commit b87591d

typing fix
1 parent 7080bf1 commit b87591d

4 files changed (+17, -18 lines)


promptolution/helpers.py

Lines changed: 5 additions & 5 deletions
@@ -161,9 +161,9 @@ def get_llm(model_id: Optional[str] = None, config: Optional["ExperimentConfig"]
 def get_task(
     df: pd.DataFrame,
     config: "ExperimentConfig",
-    task_type: TaskType = None,
-    judge_llm: "BaseLLM" = None,
-    reward_function: Callable = None,
+    task_type: Optional["TaskType"] = None,
+    judge_llm: Optional["BaseLLM"] = None,
+    reward_function: Optional[Callable] = None,
 ) -> "BaseTask":
     """Get the task based on the provided DataFrame and configuration.

@@ -198,7 +198,7 @@ def get_optimizer(
     predictor: "BasePredictor",
     meta_llm: "BaseLLM",
     task: "BaseTask",
-    optimizer: Optional[OptimizerType] = None,
+    optimizer: Optional["OptimizerType"] = None,
     task_description: Optional[str] = None,
     config: Optional["ExperimentConfig"] = None,
 ) -> "BaseOptimizer":
@@ -292,7 +292,7 @@ def get_exemplar_selector(
     raise ValueError(f"Unknown exemplar selector: {name}")


-def get_predictor(downstream_llm=None, type: PredictorType = "marker", *args, **kwargs) -> "BasePredictor":
+def get_predictor(downstream_llm=None, type: "PredictorType" = "marker", *args, **kwargs) -> "BasePredictor":
     """Factory function to create and return a predictor instance.

     This function supports three types of predictors:
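The helpers.py hunks all apply one fix: type names such as TaskType, OptimizerType, and PredictorType become quoted forward references, and parameters that default to None gain an explicit Optional[...]. Quoting matters when a name is imported only under typing.TYPE_CHECKING (a pattern this codebase uses, see the typing import in reward_tasks.py below): unless the module enables from __future__ import annotations, an unquoted annotation is evaluated at import time and raises NameError, while a quoted one stays a string that only static type checkers resolve. A minimal sketch of the idea, using a hypothetical module path and function name rather than promptolution's real API:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Imported only for the type checker; this block never runs at runtime.
    from mylib.llms import BaseLLM  # hypothetical path, purely illustrative


def make_task(judge_llm: Optional["BaseLLM"] = None) -> None:
    # Quoted "BaseLLM" is a forward reference: harmless at runtime even though
    # the class is never actually imported, yet still checked by mypy/pyright.
    # Optional[...] also makes the implicit None default explicit.
    ...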

promptolution/tasks/base_task.py

Lines changed: 9 additions & 9 deletions
@@ -27,7 +27,7 @@ def __init__(
         y_column: Optional[str] = None,
         task_description: Optional[str] = None,
         n_subsamples: int = 30,
-        eval_strategy: EvalStrategy = "full",
+        eval_strategy: "EvalStrategy" = "full",
         seed: int = 42,
         config: Optional["ExperimentConfig"] = None,
     ) -> None:
@@ -70,7 +70,7 @@ def __init__(
         self.eval_cache: Dict[Tuple[str, str, str], float] = {} # (prompt, x, y): scores per datapoint
         self.seq_cache: Dict[Tuple[str, str, str], str] = {} # (prompt, x, y): generating sequence per datapoint

-    def subsample(self, eval_strategy: EvalStrategy = None) -> Tuple[List[str], List[str]]:
+    def subsample(self, eval_strategy: "EvalStrategy" = None) -> Tuple[List[str], List[str]]:
         """Subsample the dataset based on the specified parameters.

         Args:
@@ -170,7 +170,7 @@ def evaluate(
         system_prompts: Optional[Union[str, List[str]]] = None,
         return_agg_scores: Literal[True] = True,
         return_seq: Literal[False] = False,
-        eval_strategy: Optional[EvalStrategy] = None,
+        eval_strategy: Optional["EvalStrategy"] = None,
     ) -> List[float]:
         ...

@@ -182,7 +182,7 @@ def evaluate(
         system_prompts: Optional[Union[str, List[str]]] = None,
         return_agg_scores: Literal[False] = False,
         return_seq: Literal[False] = False,
-        eval_strategy: Optional[EvalStrategy] = None,
+        eval_strategy: Optional["EvalStrategy"] = None,
     ) -> List[List[float]]:
         ...

@@ -194,7 +194,7 @@ def evaluate(
         system_prompts: Optional[Union[str, List[str]]] = None,
         return_agg_scores: Literal[False] = False,
         return_seq: Literal[True] = True,
-        eval_strategy: Optional[EvalStrategy] = None,
+        eval_strategy: Optional["EvalStrategy"] = None,
     ) -> Tuple[List[List[float]], List[List[str]]]:
         ...

@@ -206,7 +206,7 @@ def evaluate(
         system_prompts: Optional[Union[str, List[str]]] = None,
         return_agg_scores: Literal[True] = True,
         return_seq: Literal[False] = False,
-        eval_strategy: Optional[EvalStrategy] = None,
+        eval_strategy: Optional["EvalStrategy"] = None,
     ) -> List[float]:
         ...

@@ -218,7 +218,7 @@ def evaluate(
         system_prompts: Optional[Union[str, List[str]]] = None,
         return_agg_scores: Literal[False] = False,
         return_seq: Literal[False] = False,
-        eval_strategy: Optional[EvalStrategy] = None,
+        eval_strategy: Optional["EvalStrategy"] = None,
     ) -> List[List[float]]:
         ...

@@ -230,7 +230,7 @@ def evaluate(
         system_prompts: Optional[Union[str, List[str]]] = None,
         return_agg_scores: Literal[False] = False,
         return_seq: Literal[True] = True,
-        eval_strategy: Optional[EvalStrategy] = None,
+        eval_strategy: Optional["EvalStrategy"] = None,
     ) -> Tuple[List[List[float]], List[List[str]]]:
         ...

@@ -241,7 +241,7 @@ def evaluate(
         system_prompts: Optional[Union[str, List[str]]] = None,
         return_agg_scores: bool = True,
         return_seq: bool = False,
-        eval_strategy: Optional[EvalStrategy] = None,
+        eval_strategy: Optional["EvalStrategy"] = None,
     ) -> Union[List[float], List[List[float]], Tuple[List[List[float]], List[List[str]]]]:
         """Evaluate a set of prompts using a given predictor.
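The repeated one-line edits above all land inside @overload stubs of evaluate, where Literal-typed flags select the return type and EvalStrategy is now consistently a quoted forward reference. A condensed, self-contained sketch of that overload pattern follows (signatures heavily simplified, with EvalStrategy replaced by a local stand-in; this is not the library's exact API):

from typing import List, Literal, Optional, overload

EvalStrategy = str  # local stand-in so the sketch runs; the real alias lives elsewhere


@overload
def evaluate(
    prompts: List[str],
    return_agg_scores: Literal[True] = ...,
    eval_strategy: Optional["EvalStrategy"] = None,
) -> List[float]: ...


@overload
def evaluate(
    prompts: List[str],
    return_agg_scores: Literal[False] = ...,
    eval_strategy: Optional["EvalStrategy"] = None,
) -> List[List[float]]: ...


def evaluate(prompts, return_agg_scores=True, eval_strategy=None):
    # Per-datapoint scores for each prompt; collapse to one float per prompt if asked.
    per_datapoint = [[1.0, 0.0] for _ in prompts]  # placeholder scores
    if return_agg_scores:
        return [sum(s) / len(s) for s in per_datapoint]
    return per_datapoint

For a type checker, evaluate(prompts) narrows to List[float], while evaluate(prompts, return_agg_scores=False) narrows to List[List[float]]; at runtime the quoted "EvalStrategy" annotation is never evaluated, so it stays safe even if the alias is only imported under TYPE_CHECKING.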
promptolution/tasks/judge_tasks.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def __init__(
         y_column: Optional[str] = None,
         task_description: Optional[str] = None,
         n_subsamples: int = 30,
-        eval_strategy: EvalStrategy = "full",
+        eval_strategy: "EvalStrategy" = "full",
         seed: int = 42,
         judge_prompt: Optional[str] = None,
         min_score: float = -5.0,

promptolution/tasks/reward_tasks.py

Lines changed: 2 additions & 3 deletions
@@ -1,10 +1,9 @@
 """Module for Reward tasks."""


-import numpy as np
 import pandas as pd

-from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Callable, List, Optional

 from promptolution.tasks.base_task import BaseTask

@@ -27,7 +26,7 @@ def __init__(
         x_column: str = "x",
         task_description: Optional[str] = None,
         n_subsamples: int = 30,
-        eval_strategy: EvalStrategy = "full",
+        eval_strategy: "EvalStrategy" = "full",
         seed: int = 42,
         config: Optional["ExperimentConfig"] = None,
     ) -> None:
