Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 89 additions & 3 deletions src/inspect_ai/log/_recorders/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@
import os
import tempfile
from logging import getLogger
from typing import Any, BinaryIO, Literal, cast
from typing import Any, BinaryIO, Literal, cast, List, Dict, Tuple
from zipfile import ZIP_DEFLATED, ZipFile


from multiprocessing import Pool
import anyio
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from pydantic_core import to_json
from typing_extensions import override
from pydantic_core import from_json

from inspect_ai._util.constants import DESERIALIZING_CONTEXT, LOG_SCHEMA_VERSION
from inspect_ai._util.error import EvalError
from inspect_ai._util.file import FileSystem, dirname, file, filesystem
from inspect_ai._util.json import jsonable_python
from inspect_ai._util.trace import trace_action

from inspect_ai.model import ChatMessage
from .._log import (
EvalLog,
EvalPlan,
Expand Down Expand Up @@ -388,6 +391,89 @@ def _zip_writestr(self, filename: str, data: Any) -> None:
)


def _process_single_file_field(
args: Tuple[str, str], field: str
) -> List[Dict[str, Any]]:
zip_location, filename = args
with ZipFile(zip_location, mode="r") as zip:
with zip.open(filename, "r") as f:
data = json.load(f)
return data[field] if field in data else []


def _process_single_file_field_validated(
    args: Tuple[str, str], field: str
) -> List[ChatMessage]:
    """Read one sample JSON file from a zip log and validate one field.

    Like ``_process_single_file_field`` but runs the extracted list through
    pydantic validation, producing typed ``ChatMessage`` objects.

    Args:
        args: ``(zip_location, filename)`` pair; a tuple so the function can
            be used directly as a ``multiprocessing.Pool.map`` worker.
        field: Top-level key to extract. NOTE(review): the adapter validates
            ``list[ChatMessage]``, so this is only meaningful for
            ``"messages"``-shaped fields.

    Returns:
        The validated list of ``ChatMessage`` objects, or ``[]`` when the key
        is absent from the sample file.
    """
    zip_location, filename = args
    with ZipFile(zip_location, mode="r") as zf:  # avoid shadowing builtin zip
        with zf.open(filename, "r") as f:
            data = json.load(f)
    # Adapter is rebuilt per call; acceptable because each Pool worker calls
    # this in its own process and TypeAdapter construction is cheap relative
    # to the JSON parse.
    message_list_adapter = TypeAdapter(list[ChatMessage])
    return message_list_adapter.validate_python(data.get(field, []))


def process_single_file_messages_validated(
    args: Tuple[str, str],
) -> List[ChatMessage]:
    """Pool.map worker: extract and pydantic-validate a sample's "messages".

    Module-level (rather than a lambda/partial) so it is picklable by
    multiprocessing.
    """
    field = "messages"
    return _process_single_file_field_validated(args, field)


def process_single_file_messages(args: Tuple[str, str]) -> List[Dict[str, Any]]:
    """Pool.map worker: extract a sample's raw "messages" list (no validation).

    Module-level (rather than a lambda/partial) so it is picklable by
    multiprocessing.
    """
    field = "messages"
    return _process_single_file_field(args, field)


def process_single_file_events(args: Tuple[str, str]) -> List[Dict[str, Any]]:
    """Pool.map worker: extract a sample's raw "events" list (no validation).

    Module-level (rather than a lambda/partial) so it is picklable by
    multiprocessing.
    """
    field = "events"
    return _process_single_file_field(args, field)


def read_eval_log_as_json(
    location: str, field: str = "messages"
) -> List[List[Dict[str, Any]]]:
    """Extract one raw (unvalidated) field from every sample in an eval log.

    Args:
        location: Path to the zip-format eval log file.
        field: Which top-level sample field to extract; ``"messages"``
            (default) or ``"events"``.

    Returns:
        One list of raw JSON dicts per sample file found in the log's
        samples directory.

    Raises:
        ValueError: If ``field`` is not ``"messages"`` or ``"events"``.
    """
    # Resolve the worker first so an invalid field fails fast, before any
    # file I/O. (Previously an unknown field left `processor` unbound and
    # raised NameError at the Pool.map call.)
    if field == "messages":
        processor = process_single_file_messages
    elif field == "events":
        processor = process_single_file_events
    else:
        raise ValueError(
            f"unsupported field '{field}' (expected 'messages' or 'events')"
        )

    # Enumerate sample files up front; the zip can be closed immediately
    # because each worker process re-opens it independently.
    with ZipFile(location, mode="r") as zf:
        json_files = [
            name
            for name in zf.namelist()
            if name.startswith(f"{SAMPLES_DIR}/") and name.endswith(".json")
        ]

    args = [(location, name) for name in json_files]

    # Parse samples in parallel: JSON decoding is CPU-bound for large logs.
    with Pool() as pool:
        samples = pool.map(processor, args)

    return samples


def read_eval_log_as_json_validated(
    location: str, field: str = "messages"
) -> List[List[ChatMessage]]:
    """Extract and pydantic-validate one field from every sample in an eval log.

    Args:
        location: Path to the zip-format eval log file.
        field: Which top-level sample field to extract; ``"messages"``
            (default) or ``"events"``.

    Returns:
        One list of validated ``ChatMessage`` objects per sample file.
        NOTE(review): for ``field == "events"`` no validated worker exists,
        so raw dicts are returned (preserving prior behavior) — the return
        annotation only holds for the "messages" path; confirm intent.

    Raises:
        ValueError: If ``field`` is not ``"messages"`` or ``"events"``.
    """
    # Resolve the worker first so an invalid field fails fast, before any
    # file I/O. (Previously an unknown field left `processor` unbound and
    # raised NameError at the Pool.map call.)
    if field == "messages":
        processor = process_single_file_messages_validated
    elif field == "events":
        # No events validator exists; falls back to raw extraction.
        processor = process_single_file_events
    else:
        raise ValueError(
            f"unsupported field '{field}' (expected 'messages' or 'events')"
        )

    # Enumerate sample files up front; the zip can be closed immediately
    # because each worker process re-opens it independently.
    with ZipFile(location, mode="r") as zf:
        json_files = [
            name
            for name in zf.namelist()
            if name.startswith(f"{SAMPLES_DIR}/") and name.endswith(".json")
        ]

    args = [(location, name) for name in json_files]

    # Parse and validate samples in parallel across processes.
    with Pool() as pool:
        samples = pool.map(processor, args)

    return samples


def _read_log(log: BinaryIO, location: str, header_only: bool = False) -> EvalLog:
with ZipFile(log, mode="r") as zip:
evalLog = _read_header(zip, location)
Expand Down
Loading