Skip to content

Commit fc0d718

Browse files
Add kwargs instead of single arg to propagate
1 parent 13d1d2d commit fc0d718

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

guardrails/run.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -421,23 +421,13 @@ def validate(
421421
index: int,
422422
parsed_output: Any,
423423
output_schema: Schema,
424-
validate_subschema: bool = False,
424+
**kwargs,
425425
):
426426
"""Validate the output."""
427427
with start_action(action_type="validate", index=index) as action:
428-
if isinstance(output_schema, JsonSchema):
429-
validated_output = output_schema.validate(
430-
iteration,
431-
parsed_output,
432-
self.metadata,
433-
validate_subschema=validate_subschema,
434-
)
435-
else:
436-
validated_output = output_schema.validate(
437-
iteration,
438-
parsed_output,
439-
self.metadata,
440-
)
428+
validated_output = output_schema.validate(
429+
iteration, parsed_output, self.metadata, **kwargs
430+
)
441431

442432
action.log(
443433
message_type="info",
@@ -985,13 +975,19 @@ def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> st
985975
if finished:
986976
chunk_text = ""
987977
else:
988-
chunk_text = chunk["choices"][0]["text"]
978+
if "text" not in chunk["choices"][0]:
979+
chunk_text = ""
980+
else:
981+
chunk_text = chunk["choices"][0]["text"]
989982
elif isinstance(api, OpenAIChatCallable):
990983
finished = chunk["choices"][0]["finish_reason"]
991984
if finished:
992985
chunk_text = ""
993986
else:
994-
chunk_text = chunk["choices"][0]["delta"]["content"]
987+
if "content" not in chunk["choices"][0]["delta"]:
988+
chunk_text = ""
989+
else:
990+
chunk_text = chunk["choices"][0]["delta"]["content"]
995991
else:
996992
try:
997993
chunk_text = chunk

guardrails/schema.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ def reask_instructions_template(self, value: Optional[str]) -> None:
110110
else:
111111
self._reask_instructions_template = None
112112

113-
def validate(self, iteration: Iteration, data: Any, metadata: Dict) -> Any:
113+
def validate(
114+
self, iteration: Iteration, data: Any, metadata: Dict, **kwargs
115+
) -> Any:
114116
"""Validate a dictionary of data against the schema.
115117
116118
Args:
@@ -461,7 +463,7 @@ def validate(
461463
iteration: Iteration,
462464
data: Optional[Dict[str, Any]],
463465
metadata: Dict,
464-
validate_subschema: bool = False,
466+
**kwargs,
465467
) -> Any:
466468
"""Validate a dictionary of data against the schema.
467469
@@ -484,7 +486,7 @@ def validate(
484486
validated_response,
485487
prune_extra_keys=True,
486488
coerce_types=True,
487-
validate_subschema=validate_subschema,
489+
validate_subschema=kwargs.get("validate_subschema", False),
488490
):
489491
return SkeletonReAsk(
490492
incorrect_value=validated_response,
@@ -721,6 +723,7 @@ def validate(
721723
iteration: Iteration,
722724
data: Any,
723725
metadata: Dict,
726+
**kwargs,
724727
) -> Any:
725728
"""Validate a dictionary of data against the schema.
726729

0 commit comments

Comments
 (0)