@@ -303,6 +303,128 @@ def step(
303303 raise e
304304 return iteration
305305
306+ def validate_msg_history (
307+ self ,
308+ call_log : Call ,
309+ msg_history : List [Dict ],
310+ msg_history_schema : StringSchema ,
311+ ):
312+ msg_str = msg_history_string (msg_history )
313+ inputs = Inputs (
314+ llm_output = msg_str ,
315+ )
316+ iteration = Iteration (inputs = inputs )
317+ call_log .iterations .insert (0 , iteration )
318+ validated_msg_history = msg_history_schema .validate (
319+ iteration , msg_str , self .metadata
320+ )
321+ iteration .outputs .validation_output = validated_msg_history
322+ if isinstance (validated_msg_history , ReAsk ):
323+ raise ValidatorError (
324+ f"Message history validation failed: " f"{ validated_msg_history } "
325+ )
326+ if validated_msg_history != msg_str :
327+ raise ValidatorError ("Message history validation failed" )
328+
329+ def prepare_msg_history (
330+ self ,
331+ call_log : Call ,
332+ msg_history : List [Dict ],
333+ prompt_params : Dict ,
334+ msg_history_schema : Optional [StringSchema ],
335+ ):
336+ msg_history = copy .deepcopy (msg_history )
337+ # Format any variables in the message history with the prompt params.
338+ for msg in msg_history :
339+ msg ["content" ] = msg ["content" ].format (** prompt_params )
340+
341+ # validate msg_history
342+ if msg_history_schema is not None :
343+ self .validate_msg_history (call_log , msg_history , msg_history_schema )
344+
345+ return msg_history
346+
347+ def validate_prompt (
348+ self ,
349+ call_log : Call ,
350+ prompt_schema : StringSchema ,
351+ prompt : Prompt ,
352+ ):
353+ inputs = Inputs (
354+ llm_output = prompt .source ,
355+ )
356+ iteration = Iteration (inputs = inputs )
357+ call_log .iterations .insert (0 , iteration )
358+ validated_prompt = prompt_schema .validate (
359+ iteration , prompt .source , self .metadata
360+ )
361+ iteration .outputs .validation_output = validated_prompt
362+ if validated_prompt is None :
363+ raise ValidatorError ("Prompt validation failed" )
364+ if isinstance (validated_prompt , ReAsk ):
365+ raise ValidatorError (f"Prompt validation failed: { validated_prompt } " )
366+ return Prompt (validated_prompt )
367+
368+ def validate_instructions (
369+ self ,
370+ call_log : Call ,
371+ instructions_schema : StringSchema ,
372+ instructions : Instructions ,
373+ ):
374+ inputs = Inputs (
375+ llm_output = instructions .source ,
376+ )
377+ iteration = Iteration (inputs = inputs )
378+ call_log .iterations .insert (0 , iteration )
379+ validated_instructions = instructions_schema .validate (
380+ iteration , instructions .source , self .metadata
381+ )
382+ iteration .outputs .validation_output = validated_instructions
383+ if validated_instructions is None :
384+ raise ValidatorError ("Instructions validation failed" )
385+ if isinstance (validated_instructions , ReAsk ):
386+ raise ValidatorError (
387+ f"Instructions validation failed: { validated_instructions } "
388+ )
389+ return Instructions (validated_instructions )
390+
391+ def prepare_prompt (
392+ self ,
393+ call_log : Call ,
394+ instructions : Optional [Instructions ],
395+ prompt : Prompt ,
396+ prompt_params : Dict ,
397+ api : Union [PromptCallableBase , AsyncPromptCallableBase ],
398+ prompt_schema : Optional [StringSchema ],
399+ instructions_schema : Optional [StringSchema ],
400+ output_schema : Schema ,
401+ ):
402+ if isinstance (prompt , str ):
403+ prompt = Prompt (prompt )
404+
405+ prompt = prompt .format (** prompt_params )
406+
407+ # TODO(shreya): should there be any difference
408+ # to parsing params for prompt?
409+ if instructions is not None and isinstance (instructions , Instructions ):
410+ instructions = instructions .format (** prompt_params )
411+
412+ instructions , prompt = output_schema .preprocess_prompt (
413+ api , instructions , prompt
414+ )
415+
416+ # validate prompt
417+ if prompt_schema is not None and prompt is not None :
418+ prompt = self .validate_prompt (call_log , prompt_schema , prompt )
419+
420+ # validate instructions
421+ if instructions_schema is not None and instructions is not None :
422+ instructions = self .validate_instructions (
423+ call_log , instructions_schema , instructions
424+ )
425+
426+ return instructions , prompt
427+
306428 def prepare (
307429 self ,
308430 call_log : Call ,
@@ -337,32 +459,10 @@ def prepare(
337459 "not supported when using message history."
338460 )
339461 )
340- msg_history = copy .deepcopy (msg_history )
341- # Format any variables in the message history with the prompt params.
342- for msg in msg_history :
343- msg ["content" ] = msg ["content" ].format (** prompt_params )
344-
345462 prompt , instructions = None , None
346-
347- # validate msg_history
348- if msg_history_schema is not None :
349- msg_str = msg_history_string (msg_history )
350- inputs = Inputs (
351- llm_output = msg_str ,
352- )
353- iteration = Iteration (inputs = inputs )
354- call_log .iterations .insert (0 , iteration )
355- validated_msg_history = msg_history_schema .validate (
356- iteration , msg_str , self .metadata
357- )
358- iteration .outputs .validation_output = validated_msg_history
359- if isinstance (validated_msg_history , ReAsk ):
360- raise ValidatorError (
361- f"Message history validation failed: "
362- f"{ validated_msg_history } "
363- )
364- if validated_msg_history != msg_str :
365- raise ValidatorError ("Message history validation failed" )
463+ msg_history = self .prepare_msg_history (
464+ call_log , msg_history , prompt_params , msg_history_schema
465+ )
366466 elif prompt is not None :
367467 if msg_history_schema is not None :
368468 raise UserFacingException (
@@ -371,57 +471,17 @@ def prepare(
371471 "not supported when using prompt/instructions."
372472 )
373473 )
374- if isinstance (prompt , str ):
375- prompt = Prompt (prompt )
376-
377- prompt = prompt .format (** prompt_params )
378-
379- # TODO(shreya): should there be any difference
380- # to parsing params for prompt?
381- if instructions is not None and isinstance (instructions , Instructions ):
382- instructions = instructions .format (** prompt_params )
383-
384- instructions , prompt = output_schema .preprocess_prompt (
385- api , instructions , prompt
474+ msg_history = None
475+ instructions , prompt = self .prepare_prompt (
476+ call_log ,
477+ instructions ,
478+ prompt ,
479+ prompt_params ,
480+ api ,
481+ prompt_schema ,
482+ instructions_schema ,
483+ output_schema ,
386484 )
387-
388- # validate prompt
389- if prompt_schema is not None and prompt is not None :
390- inputs = Inputs (
391- llm_output = prompt .source ,
392- )
393- iteration = Iteration (inputs = inputs )
394- call_log .iterations .insert (0 , iteration )
395- validated_prompt = prompt_schema .validate (
396- iteration , prompt .source , self .metadata
397- )
398- iteration .outputs .validation_output = validated_prompt
399- if validated_prompt is None :
400- raise ValidatorError ("Prompt validation failed" )
401- if isinstance (validated_prompt , ReAsk ):
402- raise ValidatorError (
403- f"Prompt validation failed: { validated_prompt } "
404- )
405- prompt = Prompt (validated_prompt )
406-
407- # validate instructions
408- if instructions_schema is not None and instructions is not None :
409- inputs = Inputs (
410- llm_output = instructions .source ,
411- )
412- iteration = Iteration (inputs = inputs )
413- call_log .iterations .insert (0 , iteration )
414- validated_instructions = instructions_schema .validate (
415- iteration , instructions .source , self .metadata
416- )
417- iteration .outputs .validation_output = validated_instructions
418- if validated_instructions is None :
419- raise ValidatorError ("Instructions validation failed" )
420- if isinstance (validated_instructions , ReAsk ):
421- raise ValidatorError (
422- f"Instructions validation failed: { validated_instructions } "
423- )
424- instructions = Instructions (validated_instructions )
425485 else :
426486 raise UserFacingException (
427487 ValueError ("Prompt or message history must be provided." )
0 commit comments