Skip to content

Commit dfab1e8

Browse files
committed
Lint
Signed-off-by: Claudio Spiess <[email protected]>
1 parent b6b9f2a commit dfab1e8

File tree

7 files changed

+185
-82
lines changed

7 files changed

+185
-82
lines changed

src/pdl/optimize/PDLThread.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
# pylint: disable=too-many-instance-attributes
2+
import time
3+
from pathlib import Path
4+
from threading import Thread
5+
from typing import Any
6+
27
from pdl.optimize.config_parser import OptimizationConfig
38
from pdl.optimize.util import RETRY_COUNT, TrialOutput, console
49
from pdl.pdl_ast import Program, ScopeType
@@ -7,12 +12,6 @@
712
from pdl.pdl_parser import PDLParseError
813

914

10-
import time
11-
from pathlib import Path
12-
from threading import Thread
13-
from typing import Any
14-
15-
1615
class PDLThread(Thread):
1716
"""Evaluates a candidate (configuration, i.e. fewshots, style) against **one** test example."""
1817

@@ -122,13 +121,13 @@ def run(
122121
else:
123122
message = get_loc_string(exc.loc) + exc.message
124123
console.log(message)
125-
retry = True # tries < RETRY_COUNT
124+
retry = True # tries < RETRY_COUNT
126125
if tries >= RETRY_COUNT:
127126
retry = False
128127
console.log("Retrying: ", retry)
129128
exception = exc
130129
except TimeoutError as exc:
131-
retry = True # tries < RETRY_COUNT
130+
retry = True # tries < RETRY_COUNT
132131
if tries >= RETRY_COUNT:
133132
retry = False
134133
exception = exc
@@ -163,4 +162,4 @@ def run(
163162
example=self.example,
164163
total_tokens=total_tokens,
165164
index=self.index,
166-
)
165+
)

src/pdl/optimize/gsmhard_thread.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,4 @@ def extract_answer(self, document: PdlApply) -> Any:
8181
def answer_correct(self, document: str, answer: Any, truth: Any) -> bool:
8282
answerf = is_float(answer)
8383
truthf = is_float(truth)
84-
return (
85-
answer == truth or answerf == truthf or document.endswith(f" {truth}")
86-
)
84+
return answer == truth or answerf == truthf or document.endswith(f" {truth}")

src/pdl/optimize/mbpp_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def __init__(self):
2424
MBPP_OUTPUT_NOT_NONE_TASKS,
2525
)
2626

27-
self.mbpp = load_from_disk("../prompt-declaration-language-merge/var/mbpp_trajectified").rename_column(
27+
self.mbpp = load_from_disk(
28+
"../prompt-declaration-language-merge/var/mbpp_trajectified"
29+
).rename_column(
2830
"code",
2931
"canonical_solution",
3032
)

src/pdl/optimize/optimize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import argparse
2+
import traceback
23
from enum import Enum
34
from pathlib import Path
4-
import traceback
55

66
import yaml
77
from datasets import concatenate_datasets, load_dataset, load_from_disk
88

9+
from pdl.optimize.config_parser import OptimizationConfig
910
from pdl.optimize.gsmhard_thread import GsmHardTrialThread
1011
from pdl.optimize.mbpp_dataset import MBPPDataset
11-
from pdl.optimize.config_parser import OptimizationConfig
1212
from pdl.optimize.mbpp_thread import MBPPTrialThread
1313

1414
from .fever_thread import FEVERTrialThread

src/pdl/optimize/pdl_optimizer.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,7 @@
2020

2121
from pdl.optimize.config_parser import OptimizationConfig
2222
from pdl.optimize.PDLThread import PDLThread
23-
from pdl.optimize.util import (
24-
CandidateResult,
25-
TrialOutput,
26-
console,
27-
execute_threads,
28-
)
23+
from pdl.optimize.util import CandidateResult, TrialOutput, console, execute_threads
2924
from pdl.pdl_ast import DataBlock, Program
3025
from pdl.pdl_dumper import dump_yaml
3126

@@ -58,9 +53,7 @@ def dump_program(program):
5853

5954

6055
def resave_pdl(input_path: Path, output_path: Path, state: dict) -> int:
61-
with (
62-
input_path.open(encoding="utf-8") as pdl,
63-
):
56+
with (input_path.open(encoding="utf-8") as pdl,):
6457
pdl_program = Program.model_validate(yaml.safe_load(pdl))
6558

6659
for variable, value in state.items():
@@ -149,9 +142,7 @@ def parse_budget(self):
149142
self.time_budget = duration
150143

151144
def load_pdl(self, path: Path) -> Program:
152-
with (
153-
path.open(encoding="utf-8") as pdl,
154-
):
145+
with (path.open(encoding="utf-8") as pdl,):
155146
return Program.model_validate(yaml.safe_load(pdl))
156147

157148
def sample_random_indices(self, dataset: list, size: int) -> list:

src/pdl/pdl_ast.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ class Block(BaseModel):
250250
"""Current context
251251
"""
252252
# Fields for internal use
253-
pdl__id: Optional[str] = Field(default="", exclude=True)
253+
pdl__id: Optional[str] = Field(default="", exclude=True)
254254
"""Unique identifier for this block
255255
"""
256256
pdl__result: Optional[Any] = Field(default=None, exclude=True)
@@ -475,7 +475,8 @@ class CodeBlock(BaseCodeBlock):
475475
"""
476476

477477
lang: Annotated[
478-
Literal["python", "command", "jinja", "pdl", "ipython"], BeforeValidator(_ensure_lower)
478+
Literal["python", "command", "jinja", "pdl", "ipython"],
479+
BeforeValidator(_ensure_lower),
479480
]
480481
"""Programming language of the code.
481482
"""

0 commit comments

Comments
 (0)