Skip to content

Commit 6728ece

Browse files
fix(py): Updated Registry class to support schema registration and resolution
1 parent 3cfa0da commit 6728ece

File tree

12 files changed

+226
-42
lines changed

12 files changed

+226
-42
lines changed

py/packages/genkit/src/genkit/ai/_registry.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
define_helper,
6060
define_partial,
6161
define_prompt,
62+
define_schema,
6263
lookup_prompt,
6364
)
6465
from genkit.blocks.reranker import (
@@ -207,6 +208,33 @@ def define_partial(self, name: str, source: str) -> None:
207208
"""
208209
define_partial(self.registry, name, source)
209210

211+
def define_schema(self, name: str, schema: type) -> type:
212+
"""Register a Pydantic schema for use in prompts.
213+
214+
Schemas registered with this method can be referenced by name in
215+
.prompt files using the `output.schema` field.
216+
217+
Args:
218+
name: The name to register the schema under.
219+
schema: The Pydantic model class to register.
220+
221+
Returns:
222+
The schema that was registered (for convenience).
223+
224+
Example:
225+
```python
226+
RecipeSchema = ai.define_schema('Recipe', Recipe)
227+
```
228+
229+
Then in a .prompt file:
230+
```yaml
231+
output:
232+
schema: Recipe
233+
```
234+
"""
235+
define_schema(self.registry, name, schema)
236+
return schema
237+
210238
def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable], Callable]:
211239
"""Decorator to register a function as a tool.
212240
@@ -698,14 +726,14 @@ def define_prompt(
698726
model: str | None = None,
699727
config: GenerationCommonConfig | dict[str, Any] | None = None,
700728
description: str | None = None,
701-
input_schema: type | dict[str, Any] | None = None,
729+
input_schema: type | dict[str, Any] | str | None = None,
702730
system: str | Part | list[Part] | Callable | None = None,
703731
prompt: str | Part | list[Part] | Callable | None = None,
704732
messages: str | list[Message] | Callable | None = None,
705733
output_format: str | None = None,
706734
output_content_type: str | None = None,
707735
output_instructions: bool | str | None = None,
708-
output_schema: type | dict[str, Any] | None = None,
736+
output_schema: type | dict[str, Any] | str | None = None,
709737
output_constrained: bool | None = None,
710738
max_turns: int | None = None,
711739
return_tool_requests: bool | None = None,

py/packages/genkit/src/genkit/blocks/formats/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""Genkit format package. Provides implementation for various formats like json, jsonl, etc."""
1919

2020
from genkit.blocks.formats.json import JsonFormat
21+
from genkit.blocks.formats.text import TextFormat
2122
from genkit.blocks.formats.types import FormatDef, Formatter, FormatterConfig
2223

2324

@@ -26,13 +27,14 @@ def package_name() -> str:
2627
return 'genkit.blocks.formats'
2728

2829

29-
built_in_formats = [JsonFormat()]
30+
built_in_formats = [JsonFormat(), TextFormat()]
3031

3132

3233
__all__ = [
3334
FormatDef.__name__,
3435
Formatter.__name__,
3536
FormatterConfig.__name__,
3637
JsonFormat.__name__,
38+
TextFormat.__name__,
3739
package_name.__name__,
3840
]
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# SPDX-License-Identifier: Apache-2.0
16+
17+
"""Implementation of text output format."""
18+
19+
from typing import Any
20+
21+
from genkit.blocks.formats.types import FormatDef, Formatter, FormatterConfig
22+
from genkit.blocks.model import (
23+
GenerateResponseWrapper,
24+
MessageWrapper,
25+
)
26+
27+
28+
class TextFormat(FormatDef):
29+
"""Defines a text format for use with AI models.
30+
31+
This class provides functionality for parsing and formatting text data
32+
to interact with AI models.
33+
"""
34+
35+
def __init__(self):
36+
"""Initializes a TextFormat instance."""
37+
super().__init__(
38+
'text',
39+
FormatterConfig(
40+
format='text',
41+
content_type='text/plain',
42+
constrained=None,
43+
default_instructions=False,
44+
),
45+
)
46+
47+
def handle(self, schema: dict[str, Any] | None) -> Formatter:
48+
"""Creates a Formatter for handling text data.
49+
50+
Args:
51+
schema: Optional schema (ignored for text).
52+
53+
Returns:
54+
A Formatter instance configured for text handling.
55+
"""
56+
57+
def message_parser(msg: MessageWrapper):
58+
"""Extracts text from a Message object."""
59+
return msg.text
60+
61+
def chunk_parser(chunk: GenerateResponseWrapper):
62+
"""Extracts text from a GenerateResponseWrapper object."""
63+
return chunk.accumulated_text
64+
65+
return Formatter(
66+
chunk_parser=chunk_parser,
67+
message_parser=message_parser,
68+
instructions=None,
69+
)

py/packages/genkit/src/genkit/blocks/prompt.py

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,14 @@ class PromptConfig(BaseModel):
8989
model: str | None = None
9090
config: GenerationCommonConfig | dict[str, Any] | None = None
9191
description: str | None = None
92-
input_schema: type | dict[str, Any] | None = None
92+
input_schema: type | dict[str, Any] | str | None = None
9393
system: str | Part | list[Part] | Callable | None = None
9494
prompt: str | Part | list[Part] | Callable | None = None
9595
messages: str | list[Message] | Callable | None = None
9696
output_format: str | None = None
9797
output_content_type: str | None = None
9898
output_instructions: bool | str | None = None
99-
output_schema: type | dict[str, Any] | None = None
99+
output_schema: type | dict[str, Any] | str | None = None
100100
output_constrained: bool | None = None
101101
max_turns: int | None = None
102102
return_tool_requests: bool | None = None
@@ -148,14 +148,14 @@ def __init__(
148148
model: The model to use for generation.
149149
config: The generation configuration.
150150
description: A description of the prompt.
151-
input_schema: The input schema for the prompt.
152-
system: The system message for the prompt.
153-
prompt: The user prompt.
154-
messages: A list of messages to include in the prompt.
155-
output_format: The output format.
156-
output_content_type: The output content type.
151+
input_schema: type | dict[str, Any] | str | None = None,
152+
system: str | Part | list[Part] | Callable | None = None,
153+
prompt: str | Part | list[Part] | Callable | None = None,
154+
messages: str | list[Message] | Callable | None = None,
155+
output_format: str | None = None,
156+
output_content_type: str | None = None,
157157
output_instructions: Instructions for formatting the output.
158-
output_schema: The output schema.
158+
output_schema: type | dict[str, Any] | str | None = None,
159159
output_constrained: Whether the output should be constrained to the output schema.
160160
max_turns: The maximum number of turns in a conversation.
161161
return_tool_requests: Whether to return tool requests.
@@ -387,14 +387,14 @@ def define_prompt(
387387
model: str | None = None,
388388
config: GenerationCommonConfig | dict[str, Any] | None = None,
389389
description: str | None = None,
390-
input_schema: type | dict[str, Any] | None = None,
390+
input_schema: type | dict[str, Any] | str | None = None,
391391
system: str | Part | list[Part] | Callable | None = None,
392392
prompt: str | Part | list[Part] | Callable | None = None,
393393
messages: str | list[Message] | Callable | None = None,
394394
output_format: str | None = None,
395395
output_content_type: str | None = None,
396396
output_instructions: bool | str | None = None,
397-
output_schema: type | dict[str, Any] | None = None,
397+
output_schema: type | dict[str, Any] | str | None = None,
398398
output_constrained: bool | None = None,
399399
max_turns: int | None = None,
400400
return_tool_requests: bool | None = None,
@@ -541,7 +541,18 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig)
541541
if options.output_instructions is not None:
542542
output.instructions = options.output_instructions
543543
if options.output_schema:
544-
output.json_schema = to_json_schema(options.output_schema)
544+
if isinstance(options.output_schema, str):
545+
resolved_schema = registry.lookup_schema(options.output_schema)
546+
if resolved_schema:
547+
output.json_schema = resolved_schema
548+
elif options.output_constrained:
549+
# If we have a schema name but can't resolve it, and constrained is True,
550+
# we should probably error or warn. But for now, we might pass None or
551+
# try one last look up?
552+
# Actually, lookup_schema handles it. If None, we can't do much.
553+
pass
554+
else:
555+
output.json_schema = to_json_schema(options.output_schema)
545556
if options.output_constrained is not None:
546557
output.constrained = options.output_constrained
547558

@@ -940,6 +951,35 @@ def define_helper(registry: Registry, name: str, fn: Callable) -> None:
940951
logger.debug(f'Registered Dotprompt helper "{name}"')
941952

942953

954+
def define_schema(registry: Registry, name: str, schema: type) -> None:
955+
"""Register a Pydantic schema for use in prompts.
956+
957+
Schemas registered with this function can be referenced by name in
958+
.prompt files using the `output.schema` field.
959+
960+
Args:
961+
registry: The registry to register the schema in.
962+
name: The name of the schema.
963+
schema: The Pydantic model class to register.
964+
965+
Example:
966+
```python
967+
from genkit.blocks.prompt import define_schema
968+
969+
define_schema(registry, 'Recipe', Recipe)
970+
```
971+
972+
Then in a .prompt file:
973+
```yaml
974+
output:
975+
schema: Recipe
976+
```
977+
"""
978+
json_schema = to_json_schema(schema)
979+
registry.register_schema(name, json_schema)
980+
logger.debug(f'Registered schema "{name}"')
981+
982+
943983
def load_prompt(registry: Registry, path: Path, filename: str, prefix: str = '', ns: str = '') -> None:
944984
"""Load a single prompt file and register it in the registry.
945985
@@ -1001,23 +1041,46 @@ async def load_prompt_metadata():
10011041

10021042
# Convert Pydantic model to dict if needed
10031043
if hasattr(prompt_metadata, 'model_dump'):
1004-
prompt_metadata_dict = prompt_metadata.model_dump()
1044+
prompt_metadata_dict = prompt_metadata.model_dump(by_alias=True)
10051045
elif hasattr(prompt_metadata, 'dict'):
1006-
prompt_metadata_dict = prompt_metadata.dict()
1046+
prompt_metadata_dict = prompt_metadata.dict(by_alias=True)
10071047
else:
10081048
# Already a dict
10091049
prompt_metadata_dict = prompt_metadata
10101050

1051+
# Ensure raw metadata is available (critical for lazy schema resolution)
1052+
if hasattr(prompt_metadata, 'raw'):
1053+
prompt_metadata_dict['raw'] = prompt_metadata.raw
1054+
10111055
if variant:
10121056
prompt_metadata_dict['variant'] = variant
10131057

1058+
# Fallback for model if not present (Dotprompt issue)
1059+
if not prompt_metadata_dict.get('model'):
1060+
raw_model = (prompt_metadata_dict.get('raw') or {}).get('model')
1061+
if raw_model:
1062+
prompt_metadata_dict['model'] = raw_model
1063+
10141064
# Clean up null descriptions
10151065
output = prompt_metadata_dict.get('output')
1066+
schema = None
10161067
if output and isinstance(output, dict):
10171068
schema = output.get('schema')
10181069
if schema and isinstance(schema, dict) and schema.get('description') is None:
10191070
schema.pop('description', None)
10201071

1072+
if not schema:
1073+
# Fallback to raw schema name if schema definition is missing
1074+
raw_schema = (prompt_metadata_dict.get('raw') or {}).get('output', {}).get('schema')
1075+
if isinstance(raw_schema, str):
1076+
schema = raw_schema
1077+
# output might be None if it wasn't in parsed config
1078+
if not output:
1079+
output = {'schema': schema}
1080+
prompt_metadata_dict['output'] = output
1081+
elif isinstance(output, dict):
1082+
output['schema'] = schema
1083+
10211084
input_schema = prompt_metadata_dict.get('input')
10221085
if input_schema and isinstance(input_schema, dict):
10231086
schema = input_schema.get('schema')
@@ -1026,6 +1089,7 @@ async def load_prompt_metadata():
10261089

10271090
# Build metadata structure
10281091
metadata = {
1092+
**prompt_metadata_dict,
10291093
**prompt_metadata_dict.get('metadata', {}),
10301094
'type': 'prompt',
10311095
'prompt': {
@@ -1078,6 +1142,7 @@ async def create_prompt_from_file():
10781142
description=metadata.get('description'),
10791143
input_schema=metadata.get('input', {}).get('jsonSchema'),
10801144
output_schema=metadata.get('output', {}).get('jsonSchema'),
1145+
output_constrained=True if metadata.get('output', {}).get('jsonSchema') else None,
10811146
output_format=metadata.get('output', {}).get('format'),
10821147
messages=metadata.get('messages'),
10831148
max_turns=metadata.get('maxTurns'),

py/packages/genkit/src/genkit/core/registry.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,11 @@ def __init__(self):
8484
self._list_actions_resolvers: dict[str, Callable] = {}
8585
self._entries: ActionStore = {}
8686
self._value_by_kind_and_name: dict[str, dict[str, Any]] = {}
87+
self._schemas_by_name: dict[str, dict[str, Any]] = {}
8788
self._lock = threading.RLock()
88-
self.dotprompt = Dotprompt()
89+
90+
# Initialize Dotprompt with schema_resolver to match JS SDK pattern
91+
self.dotprompt = Dotprompt(schema_resolver=lambda name: self.lookup_schema(name) or name)
8992
# TODO: Figure out how to set this.
9093
self.api_stability: str = 'stable'
9194

@@ -334,3 +337,34 @@ def lookup_value(self, kind: str, name: str) -> Any | None:
334337
"""
335338
with self._lock:
336339
return self._value_by_kind_and_name.get(kind, {}).get(name)
340+
341+
def register_schema(self, name: str, schema: dict[str, Any]) -> None:
342+
"""Registers a schema by name.
343+
344+
Schemas registered with this method can be referenced by name in
345+
.prompt files using the `output.schema` field.
346+
347+
Args:
348+
name: The name of the schema.
349+
schema: The schema data (JSON schema format).
350+
351+
Raises:
352+
ValueError: If a schema with the given name is already registered.
353+
"""
354+
with self._lock:
355+
if name in self._schemas_by_name:
356+
raise ValueError(f'Schema "{name}" is already registered')
357+
self._schemas_by_name[name] = schema
358+
logger.debug(f'Registered schema "{name}"')
359+
360+
def lookup_schema(self, name: str) -> dict[str, Any] | None:
361+
"""Looks up a schema by name.
362+
363+
Args:
364+
name: The name of the schema to look up.
365+
366+
Returns:
367+
The schema data if found, None otherwise.
368+
"""
369+
with self._lock:
370+
return self._schemas_by_name.get(name)

py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,7 @@ def __init__(
7070
)
7171

7272
self.location = (
73-
location
74-
or os.getenv('GOOGLE_CLOUD_LOCATION')
75-
or os.getenv('GOOGLE_CLOUD_REGION')
76-
or const.DEFAULT_REGION
73+
location or os.getenv('GOOGLE_CLOUD_LOCATION') or os.getenv('GOOGLE_CLOUD_REGION') or const.DEFAULT_REGION
7774
)
7875

7976
self.models = models

py/samples/google-genai-context-caching/src/main.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
As a result, model is capable to quickly relate to the book's content and answer the follow-up questions.
2222
"""
2323

24-
25-
2624
import httpx
2725
import structlog
2826
from pydantic import BaseModel, Field

0 commit comments

Comments
 (0)