Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions py/packages/genkit/src/genkit/ai/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
define_helper,
define_partial,
define_prompt,
define_schema,
lookup_prompt,
)
from genkit.blocks.reranker import (
Expand Down Expand Up @@ -207,6 +208,33 @@ def define_partial(self, name: str, source: str) -> None:
"""
define_partial(self.registry, name, source)

def define_schema(self, name: str, schema: type) -> type:
    """Register a Pydantic schema for use in prompts.

    Schemas registered with this method can be referenced by name in
    .prompt files using the `output.schema` field.

    Args:
        name: The name to register the schema under.
        schema: The Pydantic model class to register.

    Returns:
        The schema that was registered (for convenience).

    Example:
        ```python
        RecipeSchema = ai.define_schema('Recipe', Recipe)
        ```

        Then in a .prompt file:

        ```yaml
        output:
            schema: Recipe
        ```
    """
    # Delegate to the module-level helper, which converts the model to a
    # JSON schema and stores it in this instance's registry.
    define_schema(self.registry, name, schema)
    # Return the input unchanged so callers can use the call as a
    # pass-through assignment.
    return schema

def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable], Callable]:
"""Decorator to register a function as a tool.
Expand Down Expand Up @@ -698,14 +726,14 @@ def define_prompt(
model: str | None = None,
config: GenerationCommonConfig | dict[str, Any] | None = None,
description: str | None = None,
input_schema: type | dict[str, Any] | None = None,
input_schema: type | dict[str, Any] | str | None = None,
system: str | Part | list[Part] | Callable | None = None,
prompt: str | Part | list[Part] | Callable | None = None,
messages: str | list[Message] | Callable | None = None,
output_format: str | None = None,
output_content_type: str | None = None,
output_instructions: bool | str | None = None,
output_schema: type | dict[str, Any] | None = None,
output_schema: type | dict[str, Any] | str | None = None,
output_constrained: bool | None = None,
max_turns: int | None = None,
return_tool_requests: bool | None = None,
Expand Down
4 changes: 3 additions & 1 deletion py/packages/genkit/src/genkit/blocks/formats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Genkit format package. Provides implementation for various formats like json, jsonl, etc."""

from genkit.blocks.formats.json import JsonFormat
from genkit.blocks.formats.text import TextFormat
from genkit.blocks.formats.types import FormatDef, Formatter, FormatterConfig


Expand All @@ -26,13 +27,14 @@ def package_name() -> str:
return 'genkit.blocks.formats'


built_in_formats = [JsonFormat()]
built_in_formats = [JsonFormat(), TextFormat()]


__all__ = [
FormatDef.__name__,
Formatter.__name__,
FormatterConfig.__name__,
JsonFormat.__name__,
TextFormat.__name__,
package_name.__name__,
]
69 changes: 69 additions & 0 deletions py/packages/genkit/src/genkit/blocks/formats/text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Implementation of text output format."""

from typing import Any

from genkit.blocks.formats.types import FormatDef, Formatter, FormatterConfig
from genkit.blocks.model import (
GenerateResponseWrapper,
MessageWrapper,
)


class TextFormat(FormatDef):
    """Plain-text output format definition.

    Registers the ``text`` format, which passes model output through
    unchanged as ``text/plain`` with no schema constraints and no
    default formatting instructions.
    """

    def __init__(self):
        """Configure the format as unconstrained plain text."""
        config = FormatterConfig(
            format='text',
            content_type='text/plain',
            constrained=None,
            default_instructions=False,
        )
        super().__init__('text', config)

    def handle(self, schema: dict[str, Any] | None) -> Formatter:
        """Build a Formatter that extracts raw text.

        Args:
            schema: Ignored; text output is never schema-constrained.

        Returns:
            A Formatter whose parsers return the plain text of a message
            or the text accumulated so far in a streaming response.
        """

        def parse_chunk(chunk: GenerateResponseWrapper):
            """Return all text accumulated so far in the stream."""
            return chunk.accumulated_text

        def parse_message(message: MessageWrapper):
            """Return the concatenated text of the message."""
            return message.text

        return Formatter(
            chunk_parser=parse_chunk,
            message_parser=parse_message,
            instructions=None,
        )
93 changes: 79 additions & 14 deletions py/packages/genkit/src/genkit/blocks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ class PromptConfig(BaseModel):
model: str | None = None
config: GenerationCommonConfig | dict[str, Any] | None = None
description: str | None = None
input_schema: type | dict[str, Any] | None = None
input_schema: type | dict[str, Any] | str | None = None
system: str | Part | list[Part] | Callable | None = None
prompt: str | Part | list[Part] | Callable | None = None
messages: str | list[Message] | Callable | None = None
output_format: str | None = None
output_content_type: str | None = None
output_instructions: bool | str | None = None
output_schema: type | dict[str, Any] | None = None
output_schema: type | dict[str, Any] | str | None = None
output_constrained: bool | None = None
max_turns: int | None = None
return_tool_requests: bool | None = None
Expand Down Expand Up @@ -148,14 +148,14 @@ def __init__(
model: The model to use for generation.
config: The generation configuration.
description: A description of the prompt.
input_schema: The input schema for the prompt, either a type, a JSON
    schema dict, or the name of a schema registered via define_schema.
system: The system message for the prompt.
prompt: The user prompt.
messages: A list of messages to include in the prompt.
output_format: The output format.
output_content_type: The output content type.
output_instructions: Instructions for formatting the output.
output_schema: The output schema, either a type, a JSON schema dict,
    or the name of a schema registered via define_schema.
output_constrained: Whether the output should be constrained to the output schema.
max_turns: The maximum number of turns in a conversation.
return_tool_requests: Whether to return tool requests.
Expand Down Expand Up @@ -387,14 +387,14 @@ def define_prompt(
model: str | None = None,
config: GenerationCommonConfig | dict[str, Any] | None = None,
description: str | None = None,
input_schema: type | dict[str, Any] | None = None,
input_schema: type | dict[str, Any] | str | None = None,
system: str | Part | list[Part] | Callable | None = None,
prompt: str | Part | list[Part] | Callable | None = None,
messages: str | list[Message] | Callable | None = None,
output_format: str | None = None,
output_content_type: str | None = None,
output_instructions: bool | str | None = None,
output_schema: type | dict[str, Any] | None = None,
output_schema: type | dict[str, Any] | str | None = None,
output_constrained: bool | None = None,
max_turns: int | None = None,
return_tool_requests: bool | None = None,
Expand Down Expand Up @@ -541,7 +541,18 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig)
if options.output_instructions is not None:
output.instructions = options.output_instructions
if options.output_schema:
output.json_schema = to_json_schema(options.output_schema)
if isinstance(options.output_schema, str):
resolved_schema = registry.lookup_schema(options.output_schema)
if resolved_schema:
output.json_schema = resolved_schema
elif options.output_constrained:
# The schema name could not be resolved from the registry. A constrained
# output without a resolvable schema cannot be enforced, so generation
# currently proceeds without a JSON schema.
# TODO(review): raise or warn here instead of failing silently.
pass
else:
output.json_schema = to_json_schema(options.output_schema)
if options.output_constrained is not None:
output.constrained = options.output_constrained

Expand Down Expand Up @@ -940,6 +951,35 @@ def define_helper(registry: Registry, name: str, fn: Callable) -> None:
logger.debug(f'Registered Dotprompt helper "{name}"')


def define_schema(registry: Registry, name: str, schema: type) -> None:
    """Register a Pydantic schema for use in prompts.

    Schemas registered with this function can be referenced by name in
    .prompt files using the `output.schema` field.

    Args:
        registry: The registry to register the schema in.
        name: The name of the schema.
        schema: The Pydantic model class to register.

    Example:
        ```python
        from genkit.blocks.prompt import define_schema

        define_schema(registry, 'Recipe', Recipe)
        ```

        Then in a .prompt file:

        ```yaml
        output:
            schema: Recipe
        ```
    """
    # Convert the Pydantic model to a JSON schema before storing, so
    # lookups always yield plain dict data.
    registry.register_schema(name, to_json_schema(schema))
    logger.debug(f'Registered schema "{name}"')


def load_prompt(registry: Registry, path: Path, filename: str, prefix: str = '', ns: str = '') -> None:
"""Load a single prompt file and register it in the registry.

Expand Down Expand Up @@ -1001,23 +1041,46 @@ async def load_prompt_metadata():

# Convert Pydantic model to dict if needed
if hasattr(prompt_metadata, 'model_dump'):
prompt_metadata_dict = prompt_metadata.model_dump()
prompt_metadata_dict = prompt_metadata.model_dump(by_alias=True)
elif hasattr(prompt_metadata, 'dict'):
prompt_metadata_dict = prompt_metadata.dict()
prompt_metadata_dict = prompt_metadata.dict(by_alias=True)
else:
# Already a dict
prompt_metadata_dict = prompt_metadata

# Ensure raw metadata is available (critical for lazy schema resolution)
if hasattr(prompt_metadata, 'raw'):
prompt_metadata_dict['raw'] = prompt_metadata.raw

if variant:
prompt_metadata_dict['variant'] = variant

# Fallback for model if not present (Dotprompt issue)
if not prompt_metadata_dict.get('model'):
raw_model = (prompt_metadata_dict.get('raw') or {}).get('model')
if raw_model:
prompt_metadata_dict['model'] = raw_model

# Clean up null descriptions
output = prompt_metadata_dict.get('output')
schema = None
if output and isinstance(output, dict):
schema = output.get('schema')
if schema and isinstance(schema, dict) and schema.get('description') is None:
schema.pop('description', None)

if not schema:
# Fallback to raw schema name if schema definition is missing
raw_schema = (prompt_metadata_dict.get('raw') or {}).get('output', {}).get('schema')
if isinstance(raw_schema, str):
schema = raw_schema
# output might be None if it wasn't in parsed config
if not output:
output = {'schema': schema}
prompt_metadata_dict['output'] = output
elif isinstance(output, dict):
output['schema'] = schema

input_schema = prompt_metadata_dict.get('input')
if input_schema and isinstance(input_schema, dict):
schema = input_schema.get('schema')
Expand All @@ -1026,6 +1089,7 @@ async def load_prompt_metadata():

# Build metadata structure
metadata = {
**prompt_metadata_dict,
**prompt_metadata_dict.get('metadata', {}),
'type': 'prompt',
'prompt': {
Expand Down Expand Up @@ -1078,6 +1142,7 @@ async def create_prompt_from_file():
description=metadata.get('description'),
input_schema=metadata.get('input', {}).get('jsonSchema'),
output_schema=metadata.get('output', {}).get('jsonSchema'),
output_constrained=True if metadata.get('output', {}).get('jsonSchema') else None,
output_format=metadata.get('output', {}).get('format'),
messages=metadata.get('messages'),
max_turns=metadata.get('maxTurns'),
Expand Down
36 changes: 35 additions & 1 deletion py/packages/genkit/src/genkit/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,11 @@ def __init__(self):
self._list_actions_resolvers: dict[str, Callable] = {}
self._entries: ActionStore = {}
self._value_by_kind_and_name: dict[str, dict[str, Any]] = {}
self._schemas_by_name: dict[str, dict[str, Any]] = {}
self._lock = threading.RLock()
self.dotprompt = Dotprompt()

# Initialize Dotprompt with schema_resolver to match JS SDK pattern
self.dotprompt = Dotprompt(schema_resolver=lambda name: self.lookup_schema(name) or name)
# TODO: Figure out how to set this.
self.api_stability: str = 'stable'

Expand Down Expand Up @@ -334,3 +337,34 @@ def lookup_value(self, kind: str, name: str) -> Any | None:
"""
with self._lock:
return self._value_by_kind_and_name.get(kind, {}).get(name)

def register_schema(self, name: str, schema: dict[str, Any]) -> None:
    """Registers a schema by name.

    Schemas registered with this method can be referenced by name in
    .prompt files using the `output.schema` field.

    Args:
        name: The name of the schema.
        schema: The schema data (JSON schema format).

    Raises:
        ValueError: If a schema with the given name is already registered.
    """
    with self._lock:
        # Names are unique; a duplicate registration is a programming
        # error rather than an update.
        duplicate = name in self._schemas_by_name
        if duplicate:
            raise ValueError(f'Schema "{name}" is already registered')
        self._schemas_by_name[name] = schema
        logger.debug(f'Registered schema "{name}"')

def lookup_schema(self, name: str) -> dict[str, Any] | None:
"""Looks up a schema by name.

Args:
name: The name of the schema to look up.

Returns:
The schema data if found, None otherwise.
"""
with self._lock:
return self._schemas_by_name.get(name)
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ def __init__(
)

self.location = (
location
or os.getenv('GOOGLE_CLOUD_LOCATION')
or os.getenv('GOOGLE_CLOUD_REGION')
or const.DEFAULT_REGION
location or os.getenv('GOOGLE_CLOUD_LOCATION') or os.getenv('GOOGLE_CLOUD_REGION') or const.DEFAULT_REGION
)

self.models = models
Expand Down
2 changes: 0 additions & 2 deletions py/samples/google-genai-context-caching/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
As a result, model is capable to quickly relate to the book's content and answer the follow-up questions.
"""



import httpx
import structlog
from pydantic import BaseModel, Field
Expand Down
Loading
Loading