Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions garak/generators/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _call_model(
self.litellm.exceptions.AuthenticationError, # authentication failed for detected or passed `provider`
self.litellm.exceptions.BadRequestError,
self.litellm.exceptions.APIError,
self.litellm.exceptions.InternalServerError,
) as e:
raise BadGeneratorException(
"Unrecoverable error during litellm completion; see log for details"
Expand Down
50 changes: 36 additions & 14 deletions garak/probes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import copy
import json
import logging
import pickle
from collections.abc import Iterable
import random
from typing import Iterable, Union
Expand Down Expand Up @@ -314,7 +315,11 @@ def _execute_attempt(self, this_attempt):
return copy.deepcopy(this_attempt)

def _execute_all(self, attempts) -> Iterable[garak.attempt.Attempt]:
"""handles sending a set of attempt to the generator"""
"""handles sending a set of attempt to the generator

When running with multiprocessing, if pickling errors occur before any
attempts complete, this falls back to a thread pool to continue work.
"""
attempts_completed: Iterable[garak.attempt.Attempt] = []

if (
Expand All @@ -336,20 +341,37 @@ def _execute_all(self, attempts) -> Iterable[garak.attempt.Attempt]:
)

try:
with Pool(pool_size) as attempt_pool:
for result in attempt_pool.imap_unordered(
self._execute_attempt, attempts
):
processed_attempt = self._postprocess_attempt(result)

_config.transient.reportfile.write(
json.dumps(processed_attempt.as_dict(), ensure_ascii=False)
+ "\n"
)
attempts_completed.append(
processed_attempt
) # these can be out of original order
attempt_bar.update(1)
def run_attempt_pool(pool_class):
with pool_class(pool_size) as attempt_pool:
for result in attempt_pool.imap_unordered(
self._execute_attempt, attempts
):
processed_attempt = self._postprocess_attempt(result)

_config.transient.reportfile.write(
json.dumps(
processed_attempt.as_dict(), ensure_ascii=False
)
+ "\n"
)
attempts_completed.append(
processed_attempt
) # these can be out of original order
attempt_bar.update(1)

try:
run_attempt_pool(Pool)
except (TypeError, AttributeError, pickle.PicklingError) as e:
if "pickle" not in str(e).lower() or attempts_completed:
raise
logging.warning(
"Parallel attempt pickling failed (%s); falling back to threads",
e,
)
from multiprocessing.pool import ThreadPool

run_attempt_pool(ThreadPool)
except OSError as o:
if o.errno == 24:
msg = "Parallelisation limit hit. Try reducing parallel_attempts or raising limit (e.g. ulimit -n 4096)"
Expand Down