Skip to content

Commit a1f7893

Browse files
committed
fix some more tests
1 parent 18591b7 commit a1f7893

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

guardrails/run/async_runner.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
from functools import partial
3-
from typing import Any, Dict, List, Optional, Union, cast
3+
from typing import Any, Dict, List, Optional, cast
44

55

66
from guardrails import validator_service
@@ -11,7 +11,7 @@
1111
from guardrails.llm_providers import AsyncPromptCallableBase
1212
from guardrails.logger import set_scope
1313
from guardrails.run.runner import Runner
14-
from guardrails.run.utils import messages_source, messages_string
14+
from guardrails.run.utils import messages_source
1515
from guardrails.schema.validator import schema_validation
1616
from guardrails.hub_telemetry.hub_tracing import async_trace
1717
from guardrails.types.inputs import MessageHistory
@@ -25,6 +25,7 @@
2525
from guardrails.constants import fail_status
2626
from guardrails.prompt import Prompt
2727

28+
2829
class AsyncRunner(Runner):
2930
def __init__(
3031
self,
@@ -331,9 +332,7 @@ async def prepare_messages(
331332
formatted_messages.append(msg_copy)
332333

333334
if "messages" in self.validation_map:
334-
await self.validate_messages(
335-
call_log, formatted_messages, attempt_number
336-
)
335+
await self.validate_messages(call_log, formatted_messages, attempt_number)
337336

338337
return formatted_messages
339338

@@ -348,9 +347,11 @@ async def validate_messages(
348347
else msg["content"]
349348
)
350349
inputs = Inputs(
351-
llm_output=content,
352-
)
353-
iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs)
350+
llm_output=content,
351+
)
352+
iteration = Iteration(
353+
call_id=call_log.id, index=attempt_number, inputs=inputs
354+
)
354355
call_log.iterations.insert(0, iteration)
355356
value, _metadata = await validator_service.async_validate(
356357
value=content,
@@ -364,7 +365,7 @@ async def validate_messages(
364365
validated_msg = validator_service.post_process_validation(
365366
value, attempt_number, iteration, OutputTypes.STRING
366367
)
367-
368+
368369
iteration.outputs.validation_response = validated_msg
369370

370371
if isinstance(validated_msg, ReAsk):
@@ -374,4 +375,4 @@ async def validate_messages(
374375

375376
msg["content"] = cast(str, validated_msg)
376377

377-
return messages # type: ignore
378+
return messages # type: ignore

guardrails/run/async_stream_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,13 @@ def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> st
205205
content = chunk.choices[0].text
206206
if not finished and content:
207207
chunk_text = content
208-
except Exception as e:
208+
except Exception:
209209
try:
210210
finished = chunk.choices[0].finish_reason
211211
content = chunk.choices[0].delta.content
212212
if not finished and content:
213213
chunk_text = content
214-
except Exception as e:
214+
except Exception:
215215
try:
216216
chunk_text = chunk
217217
except Exception as e:

guardrails/run/stream_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,13 @@ def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> st
253253
content = chunk.choices[0].text
254254
if not finished and content:
255255
chunk_text = content
256-
except Exception as e:
256+
except Exception:
257257
try:
258258
finished = chunk.choices[0].finish_reason
259259
content = chunk.choices[0].delta.content
260260
if not finished and content:
261261
chunk_text = content
262-
except Exception as e:
262+
except Exception:
263263
try:
264264
chunk_text = chunk
265265
except Exception as e:

0 commit comments

Comments
 (0)