11import asyncio
22import contextvars
3+ import warnings
34from typing import (
45 Any ,
56 Awaitable ,
2728from guardrails .prompt import Instructions , Prompt
2829from guardrails .rail import Rail
2930from guardrails .run import AsyncRunner , Runner
30- from guardrails .schema import Schema
31+ from guardrails .schema import Schema , StringSchema
3132from guardrails .validators import Validator
3233
3334add_destinations (logger .debug )
@@ -64,9 +65,19 @@ def __init__(
6465 self .base_model = base_model
6566
6667 @property
67- def input_schema (self ) -> Optional [Schema ]:
68+ def prompt_schema (self ) -> Optional [StringSchema ]:
6869 """Return the input schema."""
69- return self .rail .input_schema
70+ return self .rail .prompt_schema
71+
72+ @property
73+ def instructions_schema (self ) -> Optional [StringSchema ]:
74+ """Return the input schema."""
75+ return self .rail .instructions_schema
76+
77+ @property
78+ def msg_history_schema (self ) -> Optional [StringSchema ]:
79+ """Return the input schema."""
80+ return self .rail .msg_history_schema
7081
7182 @property
7283 def output_schema (self ) -> Schema :
@@ -377,7 +388,9 @@ def _call_sync(
377388 prompt = prompt_obj ,
378389 msg_history = msg_history_obj ,
379390 api = get_llm_ask (llm_api , * args , ** kwargs ),
380- input_schema = self .input_schema ,
391+ prompt_schema = self .prompt_schema ,
392+ instructions_schema = self .instructions_schema ,
393+ msg_history_schema = self .msg_history_schema ,
381394 output_schema = self .output_schema ,
382395 num_reasks = num_reasks ,
383396 metadata = metadata ,
@@ -434,7 +447,9 @@ async def _call_async(
434447 prompt = prompt_obj ,
435448 msg_history = msg_history_obj ,
436449 api = get_async_llm_ask (llm_api , * args , ** kwargs ),
437- input_schema = self .input_schema ,
450+ prompt_schema = self .prompt_schema ,
451+ instructions_schema = self .instructions_schema ,
452+ msg_history_schema = self .msg_history_schema ,
438453 output_schema = self .output_schema ,
439454 num_reasks = num_reasks ,
440455 metadata = metadata ,
@@ -610,7 +625,9 @@ def _sync_parse(
610625 prompt = kwargs .pop ("prompt" , None ),
611626 msg_history = kwargs .pop ("msg_history" , None ),
612627 api = get_llm_ask (llm_api , * args , ** kwargs ) if llm_api else None ,
613- input_schema = None ,
628+ prompt_schema = self .prompt_schema ,
629+ instructions_schema = self .instructions_schema ,
630+ msg_history_schema = self .msg_history_schema ,
614631 output_schema = self .output_schema ,
615632 num_reasks = num_reasks ,
616633 metadata = metadata ,
@@ -650,7 +667,9 @@ async def _async_parse(
650667 prompt = kwargs .pop ("prompt" , None ),
651668 msg_history = kwargs .pop ("msg_history" , None ),
652669 api = get_async_llm_ask (llm_api , * args , ** kwargs ) if llm_api else None ,
653- input_schema = None ,
670+ prompt_schema = self .prompt_schema ,
671+ instructions_schema = self .instructions_schema ,
672+ msg_history_schema = self .msg_history_schema ,
654673 output_schema = self .output_schema ,
655674 num_reasks = num_reasks ,
656675 metadata = metadata ,
@@ -663,3 +682,54 @@ async def _async_parse(
663682 )
664683
665684 return ValidationOutcome [OT ].from_guard_history (call , error_message )
685+
686+ def with_prompt_validation (
687+ self ,
688+ validators : Sequence [Validator ],
689+ ):
690+ """Add prompt validation to the Guard.
691+
692+ Args:
693+ validators: The validators to add to the prompt.
694+ """
695+ if self .rail .prompt_schema :
696+ warnings .warn ("Overriding existing prompt validators." )
697+ schema = StringSchema .from_string (
698+ validators = validators ,
699+ )
700+ self .rail .prompt_schema = schema
701+ return self
702+
703+ def with_instructions_validation (
704+ self ,
705+ validators : Sequence [Validator ],
706+ ):
707+ """Add instructions validation to the Guard.
708+
709+ Args:
710+ validators: The validators to add to the instructions.
711+ """
712+ if self .rail .instructions_schema :
713+ warnings .warn ("Overriding existing instructions validators." )
714+ schema = StringSchema .from_string (
715+ validators = validators ,
716+ )
717+ self .rail .instructions_schema = schema
718+ return self
719+
720+ def with_msg_history_validation (
721+ self ,
722+ validators : Sequence [Validator ],
723+ ):
724+ """Add msg_history validation to the Guard.
725+
726+ Args:
727+ validators: The validators to add to the msg_history.
728+ """
729+ if self .rail .msg_history_schema :
730+ warnings .warn ("Overriding existing msg_history validators." )
731+ schema = StringSchema .from_string (
732+ validators = validators ,
733+ )
734+ self .rail .msg_history_schema = schema
735+ return self
0 commit comments