Skip to content

Commit b5d9f6e

Browse files
committed
run: Split long control flow into submethods
1 parent 645c7c6 commit b5d9f6e

File tree

1 file changed

+135
-75
lines changed

1 file changed

+135
-75
lines changed

guardrails/run.py

Lines changed: 135 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,128 @@ def step(
303303
raise e
304304
return iteration
305305

306+
def validate_msg_history(
307+
self,
308+
call_log: Call,
309+
msg_history: List[Dict],
310+
msg_history_schema: StringSchema,
311+
):
312+
msg_str = msg_history_string(msg_history)
313+
inputs = Inputs(
314+
llm_output=msg_str,
315+
)
316+
iteration = Iteration(inputs=inputs)
317+
call_log.iterations.insert(0, iteration)
318+
validated_msg_history = msg_history_schema.validate(
319+
iteration, msg_str, self.metadata
320+
)
321+
iteration.outputs.validation_output = validated_msg_history
322+
if isinstance(validated_msg_history, ReAsk):
323+
raise ValidatorError(
324+
f"Message history validation failed: " f"{validated_msg_history}"
325+
)
326+
if validated_msg_history != msg_str:
327+
raise ValidatorError("Message history validation failed")
328+
329+
def prepare_msg_history(
330+
self,
331+
call_log: Call,
332+
msg_history: List[Dict],
333+
prompt_params: Dict,
334+
msg_history_schema: Optional[StringSchema],
335+
):
336+
msg_history = copy.deepcopy(msg_history)
337+
# Format any variables in the message history with the prompt params.
338+
for msg in msg_history:
339+
msg["content"] = msg["content"].format(**prompt_params)
340+
341+
# validate msg_history
342+
if msg_history_schema is not None:
343+
self.validate_msg_history(call_log, msg_history, msg_history_schema)
344+
345+
return msg_history
346+
347+
def validate_prompt(
348+
self,
349+
call_log: Call,
350+
prompt_schema: StringSchema,
351+
prompt: Prompt,
352+
):
353+
inputs = Inputs(
354+
llm_output=prompt.source,
355+
)
356+
iteration = Iteration(inputs=inputs)
357+
call_log.iterations.insert(0, iteration)
358+
validated_prompt = prompt_schema.validate(
359+
iteration, prompt.source, self.metadata
360+
)
361+
iteration.outputs.validation_output = validated_prompt
362+
if validated_prompt is None:
363+
raise ValidatorError("Prompt validation failed")
364+
if isinstance(validated_prompt, ReAsk):
365+
raise ValidatorError(f"Prompt validation failed: {validated_prompt}")
366+
return Prompt(validated_prompt)
367+
368+
def validate_instructions(
369+
self,
370+
call_log: Call,
371+
instructions_schema: StringSchema,
372+
instructions: Instructions,
373+
):
374+
inputs = Inputs(
375+
llm_output=instructions.source,
376+
)
377+
iteration = Iteration(inputs=inputs)
378+
call_log.iterations.insert(0, iteration)
379+
validated_instructions = instructions_schema.validate(
380+
iteration, instructions.source, self.metadata
381+
)
382+
iteration.outputs.validation_output = validated_instructions
383+
if validated_instructions is None:
384+
raise ValidatorError("Instructions validation failed")
385+
if isinstance(validated_instructions, ReAsk):
386+
raise ValidatorError(
387+
f"Instructions validation failed: {validated_instructions}"
388+
)
389+
return Instructions(validated_instructions)
390+
391+
def prepare_prompt(
392+
self,
393+
call_log: Call,
394+
instructions: Optional[Instructions],
395+
prompt: Prompt,
396+
prompt_params: Dict,
397+
api: Union[PromptCallableBase, AsyncPromptCallableBase],
398+
prompt_schema: Optional[StringSchema],
399+
instructions_schema: Optional[StringSchema],
400+
output_schema: Schema,
401+
):
402+
if isinstance(prompt, str):
403+
prompt = Prompt(prompt)
404+
405+
prompt = prompt.format(**prompt_params)
406+
407+
# TODO(shreya): should there be any difference
408+
# to parsing params for prompt?
409+
if instructions is not None and isinstance(instructions, Instructions):
410+
instructions = instructions.format(**prompt_params)
411+
412+
instructions, prompt = output_schema.preprocess_prompt(
413+
api, instructions, prompt
414+
)
415+
416+
# validate prompt
417+
if prompt_schema is not None and prompt is not None:
418+
prompt = self.validate_prompt(call_log, prompt_schema, prompt)
419+
420+
# validate instructions
421+
if instructions_schema is not None and instructions is not None:
422+
instructions = self.validate_instructions(
423+
call_log, instructions_schema, instructions
424+
)
425+
426+
return instructions, prompt
427+
306428
def prepare(
307429
self,
308430
call_log: Call,
@@ -337,32 +459,10 @@ def prepare(
337459
"not supported when using message history."
338460
)
339461
)
340-
msg_history = copy.deepcopy(msg_history)
341-
# Format any variables in the message history with the prompt params.
342-
for msg in msg_history:
343-
msg["content"] = msg["content"].format(**prompt_params)
344-
345462
prompt, instructions = None, None
346-
347-
# validate msg_history
348-
if msg_history_schema is not None:
349-
msg_str = msg_history_string(msg_history)
350-
inputs = Inputs(
351-
llm_output=msg_str,
352-
)
353-
iteration = Iteration(inputs=inputs)
354-
call_log.iterations.insert(0, iteration)
355-
validated_msg_history = msg_history_schema.validate(
356-
iteration, msg_str, self.metadata
357-
)
358-
iteration.outputs.validation_output = validated_msg_history
359-
if isinstance(validated_msg_history, ReAsk):
360-
raise ValidatorError(
361-
f"Message history validation failed: "
362-
f"{validated_msg_history}"
363-
)
364-
if validated_msg_history != msg_str:
365-
raise ValidatorError("Message history validation failed")
463+
msg_history = self.prepare_msg_history(
464+
call_log, msg_history, prompt_params, msg_history_schema
465+
)
366466
elif prompt is not None:
367467
if msg_history_schema is not None:
368468
raise UserFacingException(
@@ -371,57 +471,17 @@ def prepare(
371471
"not supported when using prompt/instructions."
372472
)
373473
)
374-
if isinstance(prompt, str):
375-
prompt = Prompt(prompt)
376-
377-
prompt = prompt.format(**prompt_params)
378-
379-
# TODO(shreya): should there be any difference
380-
# to parsing params for prompt?
381-
if instructions is not None and isinstance(instructions, Instructions):
382-
instructions = instructions.format(**prompt_params)
383-
384-
instructions, prompt = output_schema.preprocess_prompt(
385-
api, instructions, prompt
474+
msg_history = None
475+
instructions, prompt = self.prepare_prompt(
476+
call_log,
477+
instructions,
478+
prompt,
479+
prompt_params,
480+
api,
481+
prompt_schema,
482+
instructions_schema,
483+
output_schema,
386484
)
387-
388-
# validate prompt
389-
if prompt_schema is not None and prompt is not None:
390-
inputs = Inputs(
391-
llm_output=prompt.source,
392-
)
393-
iteration = Iteration(inputs=inputs)
394-
call_log.iterations.insert(0, iteration)
395-
validated_prompt = prompt_schema.validate(
396-
iteration, prompt.source, self.metadata
397-
)
398-
iteration.outputs.validation_output = validated_prompt
399-
if validated_prompt is None:
400-
raise ValidatorError("Prompt validation failed")
401-
if isinstance(validated_prompt, ReAsk):
402-
raise ValidatorError(
403-
f"Prompt validation failed: {validated_prompt}"
404-
)
405-
prompt = Prompt(validated_prompt)
406-
407-
# validate instructions
408-
if instructions_schema is not None and instructions is not None:
409-
inputs = Inputs(
410-
llm_output=instructions.source,
411-
)
412-
iteration = Iteration(inputs=inputs)
413-
call_log.iterations.insert(0, iteration)
414-
validated_instructions = instructions_schema.validate(
415-
iteration, instructions.source, self.metadata
416-
)
417-
iteration.outputs.validation_output = validated_instructions
418-
if validated_instructions is None:
419-
raise ValidatorError("Instructions validation failed")
420-
if isinstance(validated_instructions, ReAsk):
421-
raise ValidatorError(
422-
f"Instructions validation failed: {validated_instructions}"
423-
)
424-
instructions = Instructions(validated_instructions)
425485
else:
426486
raise UserFacingException(
427487
ValueError("Prompt or message history must be provided.")

0 commit comments

Comments
 (0)