Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions py/bin/sanitize_schema_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from _ast import AST
from datetime import datetime
from pathlib import Path
from typing import Type, cast
from typing import Any, Type, cast


class ClassTransformer(ast.NodeTransformer):
Expand Down Expand Up @@ -118,7 +118,18 @@ def has_model_config(self, node: ast.ClassDef) -> ast.Assign | None:
return item
return None

def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign:
"""Visit and transform annotated assignment."""
if isinstance(node.annotation, ast.Name) and node.annotation.id == 'Role':
node.annotation = ast.BinOp(
left=ast.Name(id='Role', ctx=ast.Load()),
op=ast.BitOr(),
right=ast.Name(id='str', ctx=ast.Load()),
)
self.modified = True
return node

def visit_ClassDef(self, node: ast.ClassDef) -> Any:
"""Visit and transform a class definition node.

Args:
Expand All @@ -128,11 +139,16 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
The transformed ClassDef node.
"""
# First apply base class transformations recursively
node = super().generic_visit(_node)
node = cast(ast.ClassDef, super().generic_visit(node))
new_body: list[ast.stmt | ast.Constant | ast.Assign] = []

# Handle Docstrings
if not node.body or not isinstance(node.body[0], ast.Expr) or not isinstance(node.body[0].value, ast.Constant):
if (
not node.body
or not isinstance(node.body[0], ast.Expr)
or not isinstance(node.body[0].value, ast.Constant)
or not isinstance(node.body[0].value.value, str)
):
# Generate a more descriptive docstring based on class type
if self.is_rootmodel_class(node):
docstring = f'Root model for {node.name.lower().replace("_", " ")}.'
Expand All @@ -151,13 +167,21 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802

# Handle model_config for BaseModel and RootModel
existing_model_config_assign = self.has_model_config(node)

existing_model_config_call = None
if existing_model_config_assign and isinstance(existing_model_config_assign.value, ast.Call):
existing_model_config_call = existing_model_config_assign.value

# Determine start index for iterating original body (skip docstring)
body_start_index = (
1 if (node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Str)) else 0
1
if (
node.body
and isinstance(node.body[0], ast.Expr)
and isinstance(node.body[0].value, ast.Constant)
and isinstance(node.body[0].value.value, str)
)
else 0
)

if self.is_rootmodel_class(node):
Expand Down
2 changes: 1 addition & 1 deletion py/packages/genkit/src/genkit/core/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ class Message(BaseModel):
"""Model for message data."""

model_config = ConfigDict(extra='forbid', populate_by_name=True)
role: Role
role: Role | str
content: list[Part]
metadata: dict[str, Any] | None = None

Expand Down
2 changes: 1 addition & 1 deletion py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ convention = "google"
docstring-code-format = true
docstring-code-line-length = 120
indent-style = "space"
line-ending = "lf"
quote-style = "single"
skip-magic-trailing-comma = false
line-ending = "lf"

[tool.datamodel-codegen]
#collapse-root-models = true # Don't use; produces Any as types.
Expand Down
2 changes: 1 addition & 1 deletion py/samples/evaluator-demo/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# SPDX-License-Identifier: Apache-2.0

[project]
authors = [{ name = "Google" }]
authors = [{ name = "Google" }]
dependencies = ["genkit", "pydantic>=2.0.0", "structlog>=24.0.0", "pypdf"]
description = "Genkit Python Evaluation Demo"
name = "eval-demo"
Expand Down
2 changes: 1 addition & 1 deletion py/samples/multi-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

[project]
authors = [{ name = "Google" }]
license = "Apache-2.0"
classifiers = [
"Development Status :: 3 - Alpha",
"Environment :: Console",
Expand All @@ -43,6 +42,7 @@ dependencies = [
"uvicorn>=0.34.0",
]
description = "Sample implementation to exercise the Genkit multi server manager."
license = "Apache-2.0"
name = "multi-server"
readme = "README.md"
requires-python = ">=3.10"
Expand Down
Loading