
Commit 619639a

Implement task evaluation pipeline
Implement the task evaluation pipeline with vanilla score function.
2 parents 0324e19 + dfac257, commit 619639a

14 files changed (+1422, -91 lines)

README.md

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,14 @@ run:
 python3 -m poetry install --with test
 ```

+### [Optional] Google Cloud Authentication
+
+The capability evaluation logs (evaluated using [Inspect](https://inspect.aisi.org.uk/)) are stored in a GCP bucket. Use the following command to log in using your GCP account:
+
+```bash
+gcloud auth application-default login
+```
+
 ### Run pipeline with default config

 Note: Please set the following env vars before running the command.
@@ -30,6 +38,8 @@ Note: Please set the following env vars before running the command.
 - LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
 - LANGSMITH_API_KEY=<langsmith_api_key>
 - LANGSMITH_PROJECT="automated_capability_evaluation"
+- GCP env vars:
+  - GOOGLE_CLOUD_PROJECT=<project_id>

 ```bash
 python3 src/run.py
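The new authentication step and the GOOGLE_CLOUD_PROJECT variable exist so that the google-cloud-storage client added in pyproject.toml can pick up Application Default Credentials when evaluation logs are written to the GCP bucket. The repository's own upload helper is not shown in this commit; the sketch below only illustrates how ADC and the project env var are typically consumed, and the `upload_inspect_log` name and its arguments are hypothetical.

```python
import os

# Illustrative only: how Application Default Credentials (set up via
# `gcloud auth application-default login`) are typically consumed.
# The helper name and arguments are hypothetical, not the repo's API.
from google.cloud import storage


def upload_inspect_log(local_path: str, blob_name: str) -> None:
    # storage.Client() discovers ADC automatically; the project comes from
    # the GOOGLE_CLOUD_PROJECT env var documented in the README.
    client = storage.Client(project=os.environ["GOOGLE_CLOUD_PROJECT"])
    bucket = client.bucket("ace-artifacts")  # matches results_dir: gs://ace-artifacts
    bucket.blob(blob_name).upload_from_filename(local_path)
```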

poetry.lock

Lines changed: 870 additions & 21 deletions
Some generated files are not rendered by default; the poetry.lock diff is omitted.

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -11,11 +11,13 @@ authors = [
 dynamic = ["version"]
 dependencies = [
     "datasets>=3.2.0",
+    "google-cloud-storage>=3.0.0",
     "hydra-core>=1.3.2",
+    "inspect-ai>=0.3.80",
     "langchain_openai>=0.3.6",
     "langchain>=0.3.19",
     "omegaconf>=2.3.0",
-    "openai>=1.61.1",
+    "openai>=1.68.0",
     "ratelimit>=2.2.1",
     "torchvision (>=0.21.0,<0.22.0)",
     "torchaudio (>=2.6.0,<3.0.0)",

src/capability.py

Lines changed: 151 additions & 26 deletions
@@ -2,20 +2,30 @@
 import json
 import os
 import re
+import shutil
 import sys
 from collections import defaultdict
 from typing import Any, Dict, List, Tuple

 from src.model import Model
-from src.utils.capability_utils import parse_python_class_str, read_score_inspect_json
-from src.utils.constants import (
-    NO_ANSWER_STR,
-    NON_SEED_CAPABILITIES_SCORE_DIR,
-    SEED_CAPABILITIES_SCORE_DIR,
-    TAB_W_SPACES,
+from src.utils import constants
+from src.utils.capability_utils import (
+    parse_python_class_str,
+    read_score_inspect_json,
+    run_inspect_evals,
+)
+from src.utils.data_utils import (
+    list_dir,
+    load_data,
+    path_exists,
+    transfer_inspect_log_to_gcp,
 )
-from src.utils.data_utils import load_data
 from src.utils.prompts import TASK_SOLVER_SYSTEM_PROMPT
+from src.utils.templates import (
+    INSPECT_EVALS_INIT_FILE_TEMPLATE,
+    INSPECT_EVALS_README_FILE_TEMPLATE,
+    INSPECT_EVALS_SCRIPT_FILE_TEMPLATE,
+)


 class CapabilitySeedDataset:
@@ -100,9 +110,9 @@ def __init__(self, capability_dir: str) -> None:
         self._load_capability_repr_class()

         self.score_dir = (
-            SEED_CAPABILITIES_SCORE_DIR
+            constants.SEED_CAPABILITIES_SCORE_DIR
             if self.is_seed
-            else NON_SEED_CAPABILITIES_SCORE_DIR
+            else constants.NON_SEED_CAPABILITIES_SCORE_DIR
         )

     @classmethod
@@ -208,11 +218,11 @@ def load_scores(self, scores_dir: str | None = None) -> Dict[str, float]:
         """
         scores_dir = scores_dir if scores_dir else self.score_dir
         scores_dict = defaultdict(float)
-        for model in os.listdir(scores_dir):
+        for model in list_dir(scores_dir):
             scores_file = os.path.join(
                 scores_dir, model, self.domain, f"{self.name}.json"
             )
-            if os.path.isfile(scores_file):
+            if path_exists(scores_file):
                 scores_dict[model] = read_score_inspect_json(scores_file)
         return scores_dict

@@ -286,8 +296,8 @@ def add_and_update_tasks(self, tasks: List[Dict[str, Any]]) -> None:
         # Update the capability class python file
         # Extract str which contains the repr_tasks dictionary
         # TODO: Since these are hardcoded, update when the format changes
-        prefix_str = f"def repr_tasks() -> dict[str, dict]:\n{TAB_W_SPACES}{TAB_W_SPACES}return "
-        suffix_str = f"\n\n{TAB_W_SPACES}@staticmethod\n{TAB_W_SPACES}def get_instructions(t: dict) -> str:"
+        prefix_str = f"def repr_tasks() -> dict[str, dict]:\n{constants.TAB_W_SPACES}{constants.TAB_W_SPACES}return "
+        suffix_str = f"\n\n{constants.TAB_W_SPACES}@staticmethod\n{constants.TAB_W_SPACES}def get_instructions(t: dict) -> str:"
         prev_repr_tasks_str = self.capability_repr_class_str.split(prefix_str)[
             1
         ].split(suffix_str)[0]
@@ -412,7 +422,7 @@ def _solve_task(
         # and the answer is incomplete?
         answer_pattern = r"(?i)ANSWER\s*:\s*([^\n]+)"
         match = re.search(answer_pattern, response)
-        answer = match.group(1) if match else NO_ANSWER_STR
+        answer = match.group(1) if match else constants.NO_ANSWER_STR
         metadata = {
             "raw_response": response,
             "api_metadata": metadata,
@@ -466,37 +476,152 @@ def get_tasks(self) -> List[Dict[str, Any]]:
         """
         return self._data

-    def _create_inspect_file(self) -> None:
+    def _create_inspect_file(self, path: str) -> None:
         """
         Implement pipeline to evaluate the capability using the inspect framework.

         This involves converting the METR format to inspect solvers and scorers.
         """
-        raise NotImplementedError
+        # Create JSONL dataset and store it under the inspect path
+        dataset = self.get_tasks()
+        dataset_metadata_keys = [
+            k for k in list(dataset[0].keys()) if k not in ["id", "problem", "answer"]
+        ]
+        # Write data to a dataset JSONL file
+        with open(os.path.join(path, "dataset.jsonl"), "w") as f:
+            for elm in dataset:
+                f.write(json.dumps(elm) + "\n")
+
+        # Create __init__.py and README files
+        # TODO: Add more details to the README file
+        init_file_content = INSPECT_EVALS_INIT_FILE_TEMPLATE.format(
+            capability_name=self.name,
+        ).strip("\n")
+        with open(os.path.join(path, "__init__.py"), "w") as f:
+            f.write(init_file_content)
+        readme_file_content = INSPECT_EVALS_README_FILE_TEMPLATE.format(
+            capability_name=self.name,
+            capability_description=self.description,
+        ).strip("\n")
+        with open(os.path.join(path, "README.md"), "w") as f:
+            f.write(readme_file_content)
+
+        # Create inspect evals script file
+        # TODO: How to handle more involved score functions?
+        # TODO: Do we need system prompt?
+        instruction_template = self.capability_repr_class.get_instructions(
+            {"problem": "{prompt}"}
+        )
+        score_func_prefix = f"@staticmethod\n{constants.TAB_W_SPACES}def score"
+        score_func_prefix_new = (
+            f"async {score_func_prefix.split(constants.TAB_W_SPACES)[1]}".replace(
+                "score", "_score"
+            )
+        )
+        score_func_str = f"{score_func_prefix_new}{self.capability_repr_class_str.split(score_func_prefix)[1].replace((constants.TAB_W_SPACES + constants.TAB_W_SPACES), constants.TAB_W_SPACES)}".strip(
+            "`"
+        ).strip("\n")
+        script_file_content = INSPECT_EVALS_SCRIPT_FILE_TEMPLATE.format(
+            capability_name=self.name,
+            dataset_metadata_keys=json.dumps(dataset_metadata_keys),
+            prompt_template=instruction_template,
+            score_func_t_dict_str='{"answer": target.text}',
+            score_func_str=score_func_str,
+        )
+        script_file_path = os.path.join(path, f"{self.name}.py")
+        with open(script_file_path, "w") as f:
+            f.write(script_file_content)
+        # TODO: Validate formatting of script file
+        _ = _import_from_path(
+            module_name=f"{self.name}_inspect_eval_script", file_path=script_file_path
+        )

-    def _evaluate_using_inspect(self, subject_llm: Model) -> None:  # noqa: D102
+    def _evaluate_using_inspect(self, subject_llm: Model, **kwargs: Any) -> None:
         """
-        Evaluate subject LLM on the capability using the inspect framework.
+        Evaluate the subject LLM on the capability using the Inspect framework.

-        Args
-        ----
-        subject_llm : Model
-            The LLM to use for evaluation.
+        This method uses the Inspect evaluation framework to assess the performance of
+        the provided language model (LLM) on a specific capability. It ensures that the
+        required evaluation files exist, temporarily stores logs locally, and transfers
+        them to a GCP bucket after the evaluation is complete.
+
+        Args:
+            subject_llm (Model): The LLM model to evaluate.
+            **kwargs (Any): Additional args for running the evals.
+
+        Raises
+        ------
+        FileNotFoundError: If the required Inspect evaluation path does not exist.
         """
-        raise NotImplementedError
+        inspect_path = os.path.join(constants.BASE_INSPECT_EVALS_DIR, self.name)
+        if not os.path.exists(inspect_path):
+            raise FileNotFoundError(
+                f"Inspect evaluation path does not exist: {inspect_path}. "
+                "Please ensure the inspect files are created before evaluation."
+            )
+        # Temporarily store the logs locally and then transfer them to the GCP bucket,
+        # since Inspect does not support GCP bucket paths for storing logs
+        log_dir = os.path.join(
+            self.score_dir.replace(
+                constants.GCP_BASE_ARTIFACTS_DIR, constants.BASE_ARTIFACTS_DIR
+            ),
+            subject_llm.get_model_name(),
+            self.domain,
+            self.name,
+        )
+        os.makedirs(log_dir, exist_ok=True)

-    def evaluate(self, subject_llms: List[Model]) -> None:
+        run_inspect_evals(
+            path=self.name,
+            model=subject_llm,
+            log_dir=log_dir,
+            **kwargs,
+        )
+
+        # Transfer the logs to the GCP bucket
+        transfer_inspect_log_to_gcp(
+            src_dir=log_dir,
+            gcp_dir=log_dir.replace(
+                constants.BASE_ARTIFACTS_DIR, constants.GCP_BASE_ARTIFACTS_DIR
+            ),
+        )
+        # Remove the local logs
+        shutil.rmtree(log_dir)
+
+    def evaluate(
+        self, subject_llms: List[Model], gen_args: List[Dict[Any, Any]]
+    ) -> None:
         """
         Evaluate the provided subject LLMs on the capability.

         Args
         ----
         subject_llms : List[Model]
             The list of LLMs to use for evaluation.
+        gen_args : List[Dict[Any, Any]]
+            The list of generation configurations corresponding to each LLM.
         """
+        assert len(subject_llms) == len(gen_args), (
+            "Each subject LLM must have a corresponding generation config."
+        )
+        # Create inspect script if evaluating for the first time
+        inspect_path = os.path.join(constants.BASE_INSPECT_EVALS_DIR, self.name)
+        if not os.path.exists(inspect_path):
+            os.makedirs(inspect_path)
+            self._create_inspect_file(path=inspect_path)
+
+        # Change dir to where inspect eval scripts are stored
+        # because inspect evals does not support non-relative paths
+        cwd = os.getcwd()
+        os.chdir(constants.BASE_INSPECT_EVALS_DIR)
         # TODO: Run asynchronously
-        for model in subject_llms:
-            self._evaluate_using_inspect(model)
+        for model_idx, model in enumerate(subject_llms):
+            self._evaluate_using_inspect(
                subject_llm=model,
                **gen_args[model_idx],
            )
+        # Revert to original working dir after evaluation
+        os.chdir(cwd)


 def _import_from_path(module_name: str, file_path: str) -> Any:
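`_create_inspect_file` renders `INSPECT_EVALS_SCRIPT_FILE_TEMPLATE` into `<capability_name>.py`, but the template itself is not part of this diff, so the exact generated script is not visible here. The sketch below shows one plausible shape for the generated script, assuming the standard `inspect_ai` Task/solver/scorer API; the `example_capability` task name, the prompt text, the empty metadata list, and the transplanted `_score` body are placeholders for what the template would fill in from the capability class and `dataset_metadata_keys`.

```python
"""Hypothetical rendering of INSPECT_EVALS_SCRIPT_FILE_TEMPLATE for one capability."""

from inspect_ai import Task, task
from inspect_ai.dataset import FieldSpec, json_dataset
from inspect_ai.scorer import CORRECT, INCORRECT, Score, Target, accuracy, scorer, stderr
from inspect_ai.solver import TaskState, generate, prompt_template

# Placeholder: filled from capability_repr_class.get_instructions({"problem": "{prompt}"}).
PROMPT_TEMPLATE = "Solve the following problem. End with 'ANSWER: <answer>'.\n\n{prompt}"


async def _score(t: dict, submission: str) -> float:
    # Placeholder: body transplanted from the capability class's static `score`
    # method by _create_inspect_file (the "vanilla" score function).
    return 1.0 if submission.strip() == t["answer"].strip() else 0.0


@scorer(metrics=[accuracy(), stderr()])
def capability_scorer():
    async def score(state: TaskState, target: Target) -> Score:
        # Mirrors score_func_t_dict_str='{"answer": target.text}' in the template call.
        value = await _score({"answer": target.text}, state.output.completion)
        return Score(
            value=CORRECT if value == 1.0 else INCORRECT,
            answer=state.output.completion,
        )

    return score


@task
def example_capability() -> Task:
    return Task(
        dataset=json_dataset(
            "dataset.jsonl",
            FieldSpec(input="problem", target="answer", id="id", metadata=[]),
        ),
        solver=[prompt_template(PROMPT_TEMPLATE), generate()],
        scorer=capability_scorer(),
    )
```

If `run_inspect_evals` is a thin wrapper around `inspect_ai.eval()`, the `temperature` and `max_tokens` entries forwarded from `gen_args` would presumably map onto Inspect's generation-config options.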

src/cfg/run_cfg.yaml

Lines changed: 9 additions & 2 deletions
@@ -12,14 +12,18 @@ scientist_llm:
   max_tokens: 64

 subject_llm:
-  name: Meta-Llama-3.1-70B-Instruct
+  name: gpt-4o-mini # Meta-Llama-3.1-70B-Instruct
+  generation_cfg:
+    temperature: 0.7
+    max_tokens: 8

 prompt_cfg:
   sys_msg: Complete the given task to the best of your ability.

 capabilities_cfg:
   capabilities_dir: /fs01/projects/aieng/public/ace/artifacts
-  results_dir: /fs01/projects/aieng/public/ace/artifacts
+  results_dir: gs://ace-artifacts
+  inspect_evals_dir: /fs01/projects/aieng/public/ace/inspect_evals/src/ace_evals
   domain: math
   # Number of seed capabilities to use for initial capability generation
   # Set to -1 to use all seed capabilities
@@ -33,6 +37,9 @@ capabilities_cfg:
   # Set this flag to true to use representative tasks
   # as few shot examples for task generation
   task_gen_few_shot: true
+  # Number of tasks to evaluate for each capability
+  # Set to -1 to evaluate all tasks
+  num_eval_tasks_per_capability: 1

 lbo_cfg:
   # Number of capabilities to generate using LBO
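The new `generation_cfg` block and the `gen_args` parameter added to `Capability.evaluate` (see src/capability.py above) fit together. A minimal sketch of how `src/run.py` might wire them up is shown below; the function name and loop are hypothetical, since run.py is not part of this diff.

```python
from omegaconf import DictConfig, OmegaConf


def evaluate_capabilities(cfg: DictConfig, capabilities, subject_llm) -> None:
    # One generation config per subject LLM, mirroring the new
    # Capability.evaluate(subject_llms=..., gen_args=...) signature.
    gen_args = [OmegaConf.to_container(cfg.subject_llm.generation_cfg, resolve=True)]
    for capability in capabilities:
        capability.evaluate(subject_llms=[subject_llm], gen_args=gen_args)
```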

src/create_seed_capabilities.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@
 from omegaconf import DictConfig

 from capability import CapabilitySeedDataset
-from utils.constants import GSM8K_SCORE_FUNC
+from utils import constants
 from utils.templates import CAPABILITY_CLASS_TEMPLATE


@@ -229,7 +229,7 @@ def main(cfg: DictConfig) -> None:
         capability_data=math_tasks["tasks"],
         capability_repr_tasks=capability_repr_tasks,
         capability_instructions=capability_instructions,
-        capability_score_func=GSM8K_SCORE_FUNC.strip(
+        capability_score_func=constants.GSM8K_SCORE_FUNC.strip(
             "\n"
         ),  # TODO: Change this to MATHEMATICS_SCORE_FUNC after figuring out how to implement complex score functions
         source_dataset=dataset.name,
@@ -266,7 +266,7 @@ def main(cfg: DictConfig) -> None:
         capability_data=gsm_tasks,
         capability_repr_tasks=capability_repr_tasks,
         capability_instructions=capability_instructions,
-        capability_score_func=GSM8K_SCORE_FUNC.strip("\n"),
+        capability_score_func=constants.GSM8K_SCORE_FUNC.strip("\n"),
         source_dataset=dataset.name,
     )
     print(f"Created capability {capability_name} with {len(gsm_tasks)} tasks.")

src/generate_capabilities.py

Lines changed: 8 additions & 4 deletions
@@ -7,8 +7,8 @@

 from src.capability import Capability
 from src.model import Model
+from src.utils import constants
 from src.utils.capability_utils import extract_and_parse_response
-from src.utils.constants import BASE_ARTIFACTS_DIR
 from src.utils.prompts import (
     CAPABILITY_GENERATION_SYSTEM_PROMPT,
     CAPABILITY_GENERATION_USER_PROMPT,
@@ -159,7 +159,9 @@ def generate_capabilities_using_llm(
         and metadata about the generation process.
     """
     # Select seed capabilities
-    seed_capability_dir = os.path.join(BASE_ARTIFACTS_DIR, "seed_capabilities", domain)
+    seed_capability_dir = os.path.join(
+        constants.BASE_ARTIFACTS_DIR, "seed_capabilities", domain
+    )
     seed_capabilities = _sample_seed_capabilities(
         seed_capability_dir=seed_capability_dir,
         num_seed_capabilities=num_seed_capabilities,
@@ -263,11 +265,13 @@ def generate_capabilities(
     # Set the base capability directory
     if "trial_run" in kwargs:
         base_capability_dir = os.path.join(
-            BASE_ARTIFACTS_DIR, f"capabilities_{kwargs['run_id']}", domain
+            constants.BASE_ARTIFACTS_DIR, f"capabilities_{kwargs['run_id']}", domain
         )
         os.makedirs(base_capability_dir, exist_ok=True)
     else:
-        base_capability_dir = os.path.join(BASE_ARTIFACTS_DIR, "capabilities", domain)
+        base_capability_dir = os.path.join(
+            constants.BASE_ARTIFACTS_DIR, "capabilities", domain
+        )

     # Fetch previously generated capabilities, if any
     prev_capabilities = _get_previous_capabilities(capability_dir=base_capability_dir)
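Both this file and src/create_seed_capabilities.py switch from importing individual names to importing the `constants` module itself. One practical benefit, assuming the project's pytest-based test setup, is that constants can be monkeypatched at the module level without touching call sites; the test below is only an illustration of that pattern, not a test from the repository.

```python
from src.utils import constants


def test_artifacts_dir_can_be_redirected(tmp_path, monkeypatch):
    # Because call sites now read constants.BASE_ARTIFACTS_DIR at call time,
    # patching the module attribute redirects every path built from it.
    monkeypatch.setattr(constants, "BASE_ARTIFACTS_DIR", str(tmp_path))
    assert constants.BASE_ARTIFACTS_DIR == str(tmp_path)
```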
