Skip to content

Commit 0885263

Browse files
committed
fix(optimization): resolve stop condition race and console stale display issues
1 parent 97a08f3 commit 0885263

File tree

2 files changed

+72
-14
lines changed

2 files changed

+72
-14
lines changed

dreadnode/optimization/console.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,23 @@ def _handle_event(self, event: StudyEvent[t.Any]) -> None: # noqa: PLR0912
303303
self._trials_completed += 1
304304
self._completed_evals += 1
305305
self._total_cost += event.trial.cost
306+
307+
# Check if this trial is the new best (inline check to avoid stale display)
308+
# This handles the case where NewBestTrialFound event comes after rendering
309+
if (
310+
not event.trial.is_probe
311+
and event.trial.status == "finished"
312+
and (self._best_trial is None or event.trial.score > self._best_trial.score)
313+
):
314+
self._best_trial = event.trial
306315
elif isinstance(event, NewBestTrialFound):
307316
self._best_trial = event.trial
308317
elif isinstance(event, StudyEnd):
309318
self._result = event.result
319+
# Update best trial from final result in case some trials completed
320+
# after stop condition but before we received their events
321+
if event.result.best_trial:
322+
self._best_trial = event.result.best_trial
310323

311324
self._progress.update(self._progress_task_id, completed=self._completed_evals)
312325

dreadnode/optimization/study.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def add_stop_condition(self, condition: StudyStopCondition[CandidateT]) -> te.Se
246246
self.stop_conditions.append(condition)
247247
return self
248248

249-
def _resolve_dataset(self, dataset: t.Any) -> list[AnyDict]:
249+
async def _resolve_dataset(self, dataset: t.Any) -> list[AnyDict]:
250250
"""
251251
Resolve dataset to a list in memory.
252252
Handles list, file path, or callable datasets.
@@ -266,10 +266,7 @@ def _resolve_dataset(self, dataset: t.Any) -> list[AnyDict]:
266266
if callable(dataset):
267267
result = dataset()
268268
if inspect.isawaitable(result):
269-
raise ValueError(
270-
"Async dataset callables not supported with COA 1 "
271-
"(requires eager materialization)"
272-
)
269+
result = await result
273270
return list(result) if not isinstance(result, list) else result
274271

275272
return [{}]
@@ -433,28 +430,29 @@ async def _run_evaluation(
433430
trial: Trial[CandidateT],
434431
) -> t.Any:
435432
"""Run the evaluation with the given task, dataset, and scorers."""
436-
resolved_dataset = self._resolve_dataset(dataset)
433+
resolved_dataset = await self._resolve_dataset(dataset)
437434
param_name = self._infer_candidate_param(task, trial.candidate)
438435

439436
logger.debug(
440437
f"Augmenting {len(resolved_dataset)} dataset rows with candidate "
441438
f"as parameter: {param_name}"
442439
)
443440

441+
# Check for collisions before augmentation (check all rows, not just first)
442+
if resolved_dataset:
443+
collision_count = sum(1 for row in resolved_dataset if param_name in row)
444+
if collision_count > 0:
445+
logger.warning(
446+
f"Parameter '{param_name}' exists in {collision_count}/{len(resolved_dataset)} "
447+
f"dataset rows - candidate will override existing values"
448+
)
449+
444450
# Augment every row with the candidate
445451
augmented_dataset = [{**row, param_name: trial.candidate} for row in resolved_dataset]
446452

447-
# Warn on collisions
448-
if resolved_dataset and param_name in resolved_dataset[0]:
449-
logger.warning(
450-
f"Parameter '{param_name}' already exists in dataset - "
451-
f"candidate will override existing values"
452-
)
453-
454453
evaluator = Eval(
455454
task=task,
456455
dataset=augmented_dataset,
457-
dataset_input_mapping=[param_name],
458456
scorers=scorers,
459457
hooks=self.hooks,
460458
max_consecutive_errors=self.max_consecutive_errors,
@@ -560,13 +558,22 @@ async def process_search(
560558
with contextlib.suppress(asyncio.InvalidStateError):
561559
item._future.set_result(item) # noqa: SLF001
562560

561+
# Track in-flight trials to know when to stop after stop condition
562+
in_flight_trials: set[str] = set()
563+
563564
async with stream_map_and_merge(
564565
source=self.search_strategy(optimization_context),
565566
processor=process_search,
566567
limit=self.max_evals,
567568
concurrency=self.concurrency * 2,
568569
) as events:
569570
async for event in events:
571+
# Track trial lifecycle for proper draining
572+
if isinstance(event, TrialAdded):
573+
in_flight_trials.add(event.trial.id)
574+
elif isinstance(event, (TrialComplete, TrialPruned)):
575+
in_flight_trials.discard(event.trial.id)
576+
570577
yield event
571578

572579
if isinstance(event, (TrialComplete, TrialPruned)):
@@ -597,6 +604,44 @@ async def process_search(
597604
break
598605

599606
if stop_condition_met:
607+
# Drain only in-flight trials (those started but not yet completed)
608+
logger.debug(
609+
f"Draining {len(in_flight_trials)} in-flight trials before stopping..."
610+
)
611+
async for remaining_event in events:
612+
# Skip new TrialAdded events - don't start new trials after stop
613+
if isinstance(remaining_event, TrialAdded):
614+
logger.trace(f"Skipping new trial {remaining_event.trial.id} after stop")
615+
continue
616+
617+
# Track trial completion
618+
if isinstance(remaining_event, (TrialComplete, TrialPruned)):
619+
in_flight_trials.discard(remaining_event.trial.id)
620+
621+
yield remaining_event
622+
623+
# Update best trial if a better one completes while draining
624+
if (
625+
isinstance(remaining_event, (TrialComplete, TrialPruned))
626+
and not remaining_event.trial.is_probe
627+
and remaining_event.trial.status == "finished"
628+
and (best_trial is None or remaining_event.trial.score > best_trial.score)
629+
):
630+
best_trial = remaining_event.trial
631+
logger.success(
632+
f"New best trial (while draining): "
633+
f"id={best_trial.id}, "
634+
f"step={best_trial.step}, "
635+
f"score={best_trial.score:.5f}"
636+
)
637+
yield NewBestTrialFound(
638+
study=self, trials=all_trials, probes=all_probes, trial=best_trial
639+
)
640+
641+
# Stop draining once all in-flight trials are done
642+
if not in_flight_trials:
643+
logger.debug("All in-flight trials completed, stopping.")
644+
break
600645
break
601646

602647
stop_reason = (

0 commit comments

Comments
 (0)