Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions dreadnode/eval/hooks/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def hook(event: "EvalEvent") -> "EvalReaction | None": # noqa: PLR0911
if create_task:
from dreadnode import task as dn_task

task_kwargs = event.task_kwargs
input_data = event.task_kwargs

@dn_task(
name=f"transform - input ({len(transforms)} transforms)",
Expand All @@ -44,11 +44,11 @@ async def hook(event: "EvalEvent") -> "EvalReaction | None": # noqa: PLR0911
log_output=True,
)
async def apply_task(
data: dict[str, t.Any] = task_kwargs, # Use extracted variable
data: dict[str, t.Any],
) -> dict[str, t.Any]:
return await apply_transforms_to_kwargs(data, transforms)

transformed = await apply_task()
transformed = await apply_task(input_data)
return ModifyInput(task_kwargs=transformed)

# Direct application
Expand All @@ -73,10 +73,12 @@ async def apply_task(
log_inputs=True,
log_output=True,
)
async def apply_task(data: t.Any = output_data) -> t.Any: # Use extracted variable
async def apply_task(
data: t.Any,
) -> t.Any:
return await apply_transforms_to_value(data, transforms)

transformed = await apply_task()
transformed = await apply_task(output_data)
return ModifyOutput(output=transformed)

# Direct application
Expand Down
14 changes: 7 additions & 7 deletions dreadnode/optimization/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import ConfigDict, Field, FilePath, SkipValidation, computed_field

from dreadnode.common_types import AnyDict
from dreadnode.data_types.message import Message
from dreadnode.error import AssertionFailedError
from dreadnode.eval import InputDataset
from dreadnode.eval.eval import Eval
Expand Down Expand Up @@ -379,17 +380,16 @@ async def _run_evaluation(
)
logger.trace(f"Candidate: {trial.candidate!r}")

# if dataset == [{}] or (isinstance(dataset, list) and len(dataset) == 1 and not dataset[0]):
# # Dataset is empty - this is a Study/Attack where the candidate IS the input
# dataset = [{"message": trial.candidate}]
# dataset_input_mapping = ["message"]
# else:
# dataset_input_mapping = None
dataset_input_mapping = None
# If dataset is empty and candidate is a Message, this is an airt attack scenario
if dataset == [{}] and isinstance(trial.candidate, Message):
dataset = [{"message": trial.candidate}]
dataset_input_mapping = ["message"]

evaluator = Eval(
task=task,
dataset=dataset,
# dataset_input_mapping=dataset_input_mapping,
dataset_input_mapping=dataset_input_mapping,
scorers=scorers,
hooks=self.hooks,
max_consecutive_errors=self.max_consecutive_errors,
Expand Down
Loading
Loading