11import asyncio
22import contextvars
3+ import warnings
34from typing import (
45 Any ,
56 Awaitable ,
2829from guardrails .prompt import Instructions , Prompt
2930from guardrails .rail import Rail
3031from guardrails .run import AsyncRunner , Runner , StreamRunner
31- from guardrails .schema import Schema
32+ from guardrails .schema import Schema , StringSchema
3233from guardrails .validators import Validator
3334
3435add_destinations (logger .debug )
@@ -65,9 +66,19 @@ def __init__(
6566 self .base_model = base_model
6667
6768 @property
68- def input_schema (self ) -> Optional [Schema ]:
69+ def prompt_schema (self ) -> Optional [StringSchema ]:
6970 """Return the input schema."""
70- return self .rail .input_schema
71+ return self .rail .prompt_schema
72+
73+ @property
74+ def instructions_schema (self ) -> Optional [StringSchema ]:
75+ """Return the input schema."""
76+ return self .rail .instructions_schema
77+
78+ @property
79+ def msg_history_schema (self ) -> Optional [StringSchema ]:
80+ """Return the input schema."""
81+ return self .rail .msg_history_schema
7182
7283 @property
7384 def output_schema (self ) -> Schema :
@@ -458,7 +469,9 @@ async def _call_async(
458469 prompt = prompt_obj ,
459470 msg_history = msg_history_obj ,
460471 api = get_async_llm_ask (llm_api , * args , ** kwargs ),
461- input_schema = self .input_schema ,
472+ prompt_schema = self .prompt_schema ,
473+ instructions_schema = self .instructions_schema ,
474+ msg_history_schema = self .msg_history_schema ,
462475 output_schema = self .output_schema ,
463476 num_reasks = num_reasks ,
464477 metadata = metadata ,
@@ -634,7 +647,9 @@ def _sync_parse(
634647 prompt = kwargs .pop ("prompt" , None ),
635648 msg_history = kwargs .pop ("msg_history" , None ),
636649 api = get_llm_ask (llm_api , * args , ** kwargs ) if llm_api else None ,
637- input_schema = None ,
650+ prompt_schema = self .prompt_schema ,
651+ instructions_schema = self .instructions_schema ,
652+ msg_history_schema = self .msg_history_schema ,
638653 output_schema = self .output_schema ,
639654 num_reasks = num_reasks ,
640655 metadata = metadata ,
@@ -674,7 +689,9 @@ async def _async_parse(
674689 prompt = kwargs .pop ("prompt" , None ),
675690 msg_history = kwargs .pop ("msg_history" , None ),
676691 api = get_async_llm_ask (llm_api , * args , ** kwargs ) if llm_api else None ,
677- input_schema = None ,
692+ prompt_schema = self .prompt_schema ,
693+ instructions_schema = self .instructions_schema ,
694+ msg_history_schema = self .msg_history_schema ,
678695 output_schema = self .output_schema ,
679696 num_reasks = num_reasks ,
680697 metadata = metadata ,
@@ -687,3 +704,54 @@ async def _async_parse(
687704 )
688705
689706 return ValidationOutcome [OT ].from_guard_history (call , error_message )
707+
708+ def with_prompt_validation (
709+ self ,
710+ validators : Sequence [Validator ],
711+ ):
712+ """Add prompt validation to the Guard.
713+
714+ Args:
715+ validators: The validators to add to the prompt.
716+ """
717+ if self .rail .prompt_schema :
718+ warnings .warn ("Overriding existing prompt validators." )
719+ schema = StringSchema .from_string (
720+ validators = validators ,
721+ )
722+ self .rail .prompt_schema = schema
723+ return self
724+
725+ def with_instructions_validation (
726+ self ,
727+ validators : Sequence [Validator ],
728+ ):
729+ """Add instructions validation to the Guard.
730+
731+ Args:
732+ validators: The validators to add to the instructions.
733+ """
734+ if self .rail .instructions_schema :
735+ warnings .warn ("Overriding existing instructions validators." )
736+ schema = StringSchema .from_string (
737+ validators = validators ,
738+ )
739+ self .rail .instructions_schema = schema
740+ return self
741+
742+ def with_msg_history_validation (
743+ self ,
744+ validators : Sequence [Validator ],
745+ ):
746+ """Add msg_history validation to the Guard.
747+
748+ Args:
749+ validators: The validators to add to the msg_history.
750+ """
751+ if self .rail .msg_history_schema :
752+ warnings .warn ("Overriding existing msg_history validators." )
753+ schema = StringSchema .from_string (
754+ validators = validators ,
755+ )
756+ self .rail .msg_history_schema = schema
757+ return self
0 commit comments