Skip to content

Commit 1d45840

Browse files
committed
optimize capo survival
1 parent 350b54e commit 1d45840

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

promptolution/optimizers/capo.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,10 @@ def filter_survivors(
320320
Returns:
321321
Tuple[List[Prompt], List[List[float]]]: Filtered candidates and their scores.
322322
"""
323-
filtered_candidates = list(compress(candidates, mask))
324-
filtered_scores = list(compress(scores, mask))
323+
assert len(candidates) == len(mask), "Length of candidates, scores, and mask must be the same."
324+
assert all(len(score) == len(mask) for score in scores), "Length of candidates, scores, and mask must be the same."
325+
326+
filtered_candidates = [c for c, m in zip(candidates, mask) if m]
327+
filtered_scores = [[s for s, m in zip(score, mask) if m] for score in scores]
328+
325329
return filtered_candidates, filtered_scores

0 commit comments

Comments
 (0)