@@ -246,7 +246,7 @@ def add_stop_condition(self, condition: StudyStopCondition[CandidateT]) -> te.Se
246246 self .stop_conditions .append (condition )
247247 return self
248248
249- def _resolve_dataset (self , dataset : t .Any ) -> list [AnyDict ]:
249+ async def _resolve_dataset (self , dataset : t .Any ) -> list [AnyDict ]:
250250 """
251251 Resolve dataset to a list in memory.
252252 Handles list, file path, or callable datasets.
@@ -266,10 +266,7 @@ def _resolve_dataset(self, dataset: t.Any) -> list[AnyDict]:
266266 if callable (dataset ):
267267 result = dataset ()
268268 if inspect .isawaitable (result ):
269- raise ValueError (
270- "Async dataset callables not supported with COA 1 "
271- "(requires eager materialization)"
272- )
269+ result = await result
273270 return list (result ) if not isinstance (result , list ) else result
274271
275272 return [{}]
@@ -433,28 +430,29 @@ async def _run_evaluation(
433430 trial : Trial [CandidateT ],
434431 ) -> t .Any :
435432 """Run the evaluation with the given task, dataset, and scorers."""
436- resolved_dataset = self ._resolve_dataset (dataset )
433+ resolved_dataset = await self ._resolve_dataset (dataset )
437434 param_name = self ._infer_candidate_param (task , trial .candidate )
438435
439436 logger .debug (
440437 f"Augmenting { len (resolved_dataset )} dataset rows with candidate "
441438 f"as parameter: { param_name } "
442439 )
443440
441+ # Check for collisions before augmentation (check all rows, not just first)
442+ if resolved_dataset :
443+ collision_count = sum (1 for row in resolved_dataset if param_name in row )
444+ if collision_count > 0 :
445+ logger .warning (
446+ f"Parameter '{ param_name } ' exists in { collision_count } /{ len (resolved_dataset )} "
447+ f"dataset rows - candidate will override existing values"
448+ )
449+
444450 # Augment every row with the candidate
445451 augmented_dataset = [{** row , param_name : trial .candidate } for row in resolved_dataset ]
446452
447- # Warn on collisions
448- if resolved_dataset and param_name in resolved_dataset [0 ]:
449- logger .warning (
450- f"Parameter '{ param_name } ' already exists in dataset - "
451- f"candidate will override existing values"
452- )
453-
454453 evaluator = Eval (
455454 task = task ,
456455 dataset = augmented_dataset ,
457- dataset_input_mapping = [param_name ],
458456 scorers = scorers ,
459457 hooks = self .hooks ,
460458 max_consecutive_errors = self .max_consecutive_errors ,
@@ -560,13 +558,22 @@ async def process_search(
560558 with contextlib .suppress (asyncio .InvalidStateError ):
561559 item ._future .set_result (item ) # noqa: SLF001
562560
561+ # Track in-flight trials to know when to stop after stop condition
562+ in_flight_trials : set [str ] = set ()
563+
563564 async with stream_map_and_merge (
564565 source = self .search_strategy (optimization_context ),
565566 processor = process_search ,
566567 limit = self .max_evals ,
567568 concurrency = self .concurrency * 2 ,
568569 ) as events :
569570 async for event in events :
571+ # Track trial lifecycle for proper draining
572+ if isinstance (event , TrialAdded ):
573+ in_flight_trials .add (event .trial .id )
574+ elif isinstance (event , (TrialComplete , TrialPruned )):
575+ in_flight_trials .discard (event .trial .id )
576+
570577 yield event
571578
572579 if isinstance (event , (TrialComplete , TrialPruned )):
@@ -597,6 +604,44 @@ async def process_search(
597604 break
598605
599606 if stop_condition_met :
607+ # Drain only in-flight trials (those started but not yet completed)
608+ logger .debug (
609+ f"Draining { len (in_flight_trials )} in-flight trials before stopping..."
610+ )
611+ async for remaining_event in events :
612+ # Skip new TrialAdded events - don't start new trials after stop
613+ if isinstance (remaining_event , TrialAdded ):
614+ logger .trace (f"Skipping new trial { remaining_event .trial .id } after stop" )
615+ continue
616+
617+ # Track trial completion
618+ if isinstance (remaining_event , (TrialComplete , TrialPruned )):
619+ in_flight_trials .discard (remaining_event .trial .id )
620+
621+ yield remaining_event
622+
623+ # Update best trial if a better one completes while draining
624+ if (
625+ isinstance (remaining_event , (TrialComplete , TrialPruned ))
626+ and not remaining_event .trial .is_probe
627+ and remaining_event .trial .status == "finished"
628+ and (best_trial is None or remaining_event .trial .score > best_trial .score )
629+ ):
630+ best_trial = remaining_event .trial
631+ logger .success (
632+ f"New best trial (while draining): "
633+ f"id={ best_trial .id } , "
634+ f"step={ best_trial .step } , "
635+ f"score={ best_trial .score :.5f} "
636+ )
637+ yield NewBestTrialFound (
638+ study = self , trials = all_trials , probes = all_probes , trial = best_trial
639+ )
640+
641+ # Stop draining once all in-flight trials are done
642+ if not in_flight_trials :
643+ logger .debug ("All in-flight trials completed, stopping." )
644+ break
600645 break
601646
602647 stop_reason = (
0 commit comments