@@ -89,14 +89,14 @@ class PromptConfig(BaseModel):
     model: str | None = None
     config: GenerationCommonConfig | dict[str, Any] | None = None
     description: str | None = None
-    input_schema: type | dict[str, Any] | None = None
+    input_schema: type | dict[str, Any] | str | None = None
     system: str | Part | list[Part] | Callable | None = None
     prompt: str | Part | list[Part] | Callable | None = None
     messages: str | list[Message] | Callable | None = None
     output_format: str | None = None
     output_content_type: str | None = None
     output_instructions: bool | str | None = None
-    output_schema: type | dict[str, Any] | None = None
+    output_schema: type | dict[str, Any] | str | None = None
     output_constrained: bool | None = None
     max_turns: int | None = None
     return_tool_requests: bool | None = None
@@ -148,14 +148,14 @@ def __init__(
             model: The model to use for generation.
             config: The generation configuration.
             description: A description of the prompt.
-            input_schema: The input schema for the prompt.
-            system: The system message for the prompt.
-            prompt: The user prompt.
-            messages: A list of messages to include in the prompt.
-            output_format: The output format.
-            output_content_type: The output content type.
+            input_schema: The input schema for the prompt, given as a type, a JSON schema dict, or the name of a registered schema.
+            system: The system message for the prompt.
+            prompt: The user prompt.
+            messages: A list of messages to include in the prompt.
+            output_format: The output format.
+            output_content_type: The output content type.
             output_instructions: Instructions for formatting the output.
-            output_schema: The output schema.
+            output_schema: The output schema, given as a type, a JSON schema dict, or the name of a registered schema.
             output_constrained: Whether the output should be constrained to the output schema.
             max_turns: The maximum number of turns in a conversation.
             return_tool_requests: Whether to return tool requests.
@@ -387,14 +387,14 @@ def define_prompt(
     model: str | None = None,
     config: GenerationCommonConfig | dict[str, Any] | None = None,
     description: str | None = None,
-    input_schema: type | dict[str, Any] | None = None,
+    input_schema: type | dict[str, Any] | str | None = None,
     system: str | Part | list[Part] | Callable | None = None,
     prompt: str | Part | list[Part] | Callable | None = None,
     messages: str | list[Message] | Callable | None = None,
     output_format: str | None = None,
     output_content_type: str | None = None,
     output_instructions: bool | str | None = None,
-    output_schema: type | dict[str, Any] | None = None,
+    output_schema: type | dict[str, Any] | str | None = None,
     output_constrained: bool | None = None,
     max_turns: int | None = None,
     return_tool_requests: bool | None = None,
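
For orientation, the widened unions above let a prompt defined in code point at a schema registered elsewhere by passing its name. A minimal sketch, assuming `define_prompt` accepts the registry as its first positional argument; the `Registry` import path and the `Recipe` model are illustrative placeholders, not taken from this commit:

```python
from pydantic import BaseModel

from genkit.blocks.prompt import define_prompt, define_schema
from genkit.core.registry import Registry  # import path assumed for illustration


class Recipe(BaseModel):
    """Illustrative output model."""

    title: str
    ingredients: list[str]


registry = Registry()

# Register once, then refer to the schema by name wherever a
# `type | dict[str, Any] | str` schema is accepted.
define_schema(registry, 'Recipe', Recipe)

recipe_prompt = define_prompt(
    registry,
    prompt='Suggest a simple dinner recipe.',
    output_schema='Recipe',  # resolved through the registry at generate time
)
```
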
@@ -541,7 +541,18 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig)
     if options.output_instructions is not None:
         output.instructions = options.output_instructions
     if options.output_schema:
-        output.json_schema = to_json_schema(options.output_schema)
+        if isinstance(options.output_schema, str):
+            # A string names a schema registered via define_schema; resolve it
+            # to a JSON schema through the registry.
+            resolved_schema = registry.lookup_schema(options.output_schema)
+            if resolved_schema:
+                output.json_schema = resolved_schema
+            else:
+                # Unknown schema name: leave json_schema unset and let the
+                # downstream generate call report the missing schema.
+                logger.warning(f'Unable to resolve schema "{options.output_schema}"')
+        else:
+            output.json_schema = to_json_schema(options.output_schema)
     if options.output_constrained is not None:
         output.constrained = options.output_constrained
 
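
Both arms of the branch above end up assigning a plain JSON schema dict. A rough equivalence sketch (imports omitted), assuming `lookup_schema` returns exactly the dict that `register_schema` stored, and reusing the illustrative `registry` and `Recipe` from the previous example:

```python
# Passing the class converts it on the spot...
direct = to_json_schema(Recipe)

# ...while passing the name defers to whatever define_schema stored earlier.
define_schema(registry, 'Recipe', Recipe)
by_name = registry.lookup_schema('Recipe')

# Under the stated assumption, both paths yield the same JSON schema dict.
assert by_name == direct
```
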
@@ -940,6 +951,35 @@ def define_helper(registry: Registry, name: str, fn: Callable) -> None:
     logger.debug(f'Registered Dotprompt helper "{name}"')
 
 
+def define_schema(registry: Registry, name: str, schema: type) -> None:
+    """Register a Pydantic schema for use in prompts.
+
+    Schemas registered with this function can be referenced by name in
+    .prompt files using the `output.schema` field.
+
+    Args:
+        registry: The registry to register the schema in.
+        name: The name of the schema.
+        schema: The Pydantic model class to register.
+
+    Example:
+        ```python
+        from genkit.blocks.prompt import define_schema
+
+        define_schema(registry, 'Recipe', Recipe)
+        ```
+
+        Then in a .prompt file:
+        ```yaml
+        output:
+          schema: Recipe
+        ```
+    """
+    json_schema = to_json_schema(schema)
+    registry.register_schema(name, json_schema)
+    logger.debug(f'Registered schema "{name}"')
+
+
 def load_prompt(registry: Registry, path: Path, filename: str, prefix: str = '', ns: str = '') -> None:
     """Load a single prompt file and register it in the registry.
 
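
On the file side, a .prompt file can now name the registered schema instead of inlining one. A hedged sketch of what such a file might look like; everything except the `output.schema` reference is illustrative:

```yaml
---
model: googleai/gemini-2.0-flash  # illustrative model id
output:
  schema: Recipe   # name registered via define_schema(registry, 'Recipe', Recipe)
---
Suggest a simple dinner recipe.
```

The loader changes below keep that name around as a plain string (via the raw front matter) so it can be resolved against the registry when the prompt is rendered.
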
@@ -1001,23 +1041,46 @@ async def load_prompt_metadata():
 
         # Convert Pydantic model to dict if needed
         if hasattr(prompt_metadata, 'model_dump'):
-            prompt_metadata_dict = prompt_metadata.model_dump()
+            prompt_metadata_dict = prompt_metadata.model_dump(by_alias=True)
         elif hasattr(prompt_metadata, 'dict'):
-            prompt_metadata_dict = prompt_metadata.dict()
+            prompt_metadata_dict = prompt_metadata.dict(by_alias=True)
         else:
             # Already a dict
             prompt_metadata_dict = prompt_metadata
 
+        # Ensure raw metadata is available (critical for lazy schema resolution)
+        if hasattr(prompt_metadata, 'raw'):
+            prompt_metadata_dict['raw'] = prompt_metadata.raw
+
         if variant:
             prompt_metadata_dict['variant'] = variant
 
+        # Fall back to the raw front matter when the parsed metadata omits the model
+        if not prompt_metadata_dict.get('model'):
+            raw_model = (prompt_metadata_dict.get('raw') or {}).get('model')
+            if raw_model:
+                prompt_metadata_dict['model'] = raw_model
+
         # Clean up null descriptions
         output = prompt_metadata_dict.get('output')
+        schema = None
         if output and isinstance(output, dict):
             schema = output.get('schema')
             if schema and isinstance(schema, dict) and schema.get('description') is None:
                 schema.pop('description', None)
 
+        if not schema:
+            # Fall back to the raw front matter, which may name a registered schema as a string.
+            raw_schema = ((prompt_metadata_dict.get('raw') or {}).get('output') or {}).get('schema')
+            if isinstance(raw_schema, str):
+                schema = raw_schema
+                # output may be None if it was absent from the parsed config
+                if not output:
+                    output = {'schema': schema}
+                    prompt_metadata_dict['output'] = output
+                elif isinstance(output, dict):
+                    output['schema'] = schema
+
         input_schema = prompt_metadata_dict.get('input')
         if input_schema and isinstance(input_schema, dict):
             schema = input_schema.get('schema')
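
To make the fallback above concrete, here is a rough sketch of the dict shapes involved; the exact structures Dotprompt produces are assumptions, not taken from its output:

```python
# What the loader might see for a .prompt file whose front matter contains
# `output: {schema: Recipe}` when the parser cannot expand the named schema:
prompt_metadata_dict = {
    'model': None,
    'output': None,  # parsed config dropped the schema reference
    'raw': {
        'model': 'googleai/gemini-2.0-flash',  # illustrative model id
        'output': {'schema': 'Recipe'},
    },
}

# After the fallbacks above run, the name is carried through as a string so
# to_generate_action_options can resolve it against the registry later:
# prompt_metadata_dict['model']  -> 'googleai/gemini-2.0-flash'
# prompt_metadata_dict['output'] -> {'schema': 'Recipe'}
```
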
@@ -1026,6 +1089,7 @@ async def load_prompt_metadata():
 
         # Build metadata structure
         metadata = {
+            **prompt_metadata_dict,
             **prompt_metadata_dict.get('metadata', {}),
             'type': 'prompt',
             'prompt': {
@@ -1078,6 +1142,7 @@ async def create_prompt_from_file():
             description=metadata.get('description'),
             input_schema=metadata.get('input', {}).get('jsonSchema'),
             output_schema=metadata.get('output', {}).get('jsonSchema'),
+            output_constrained=True if metadata.get('output', {}).get('jsonSchema') else None,
             output_format=metadata.get('output', {}).get('format'),
             messages=metadata.get('messages'),
             max_turns=metadata.get('maxTurns'),