diff --git a/py/bin/sanitize_schema_typing.py b/py/bin/sanitize_schema_typing.py index 6138127e9f..48c6baf919 100644 --- a/py/bin/sanitize_schema_typing.py +++ b/py/bin/sanitize_schema_typing.py @@ -45,7 +45,7 @@ from _ast import AST from datetime import datetime from pathlib import Path -from typing import Type, cast +from typing import Any, Type, cast class ClassTransformer(ast.NodeTransformer): @@ -118,7 +118,18 @@ def has_model_config(self, node: ast.ClassDef) -> ast.Assign | None: return item return None - def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802 + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: + """Visit and transform annotated assignment.""" + if isinstance(node.annotation, ast.Name) and node.annotation.id == 'Role': + node.annotation = ast.BinOp( + left=ast.Name(id='Role', ctx=ast.Load()), + op=ast.BitOr(), + right=ast.Name(id='str', ctx=ast.Load()), + ) + self.modified = True + return node + + def visit_ClassDef(self, node: ast.ClassDef) -> Any: """Visit and transform a class definition node. Args: @@ -128,11 +139,16 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802 The transformed ClassDef node. """ # First apply base class transformations recursively - node = super().generic_visit(_node) + node = cast(ast.ClassDef, super().generic_visit(node)) new_body: list[ast.stmt | ast.Constant | ast.Assign] = [] # Handle Docstrings - if not node.body or not isinstance(node.body[0], ast.Expr) or not isinstance(node.body[0].value, ast.Constant): + if ( + not node.body + or not isinstance(node.body[0], ast.Expr) + or not isinstance(node.body[0].value, ast.Constant) + or not isinstance(node.body[0].value.value, str) + ): # Generate a more descriptive docstring based on class type if self.is_rootmodel_class(node): docstring = f'Root model for {node.name.lower().replace("_", " ")}.' @@ -151,13 +167,21 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802 # Handle model_config for BaseModel and RootModel existing_model_config_assign = self.has_model_config(node) + existing_model_config_call = None if existing_model_config_assign and isinstance(existing_model_config_assign.value, ast.Call): existing_model_config_call = existing_model_config_assign.value # Determine start index for iterating original body (skip docstring) body_start_index = ( - 1 if (node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Str)) else 0 + 1 + if ( + node.body + and isinstance(node.body[0], ast.Expr) + and isinstance(node.body[0].value, ast.Constant) + and isinstance(node.body[0].value.value, str) + ) + else 0 ) if self.is_rootmodel_class(node): diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index dce5a4f7f1..68fd541101 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -913,7 +913,7 @@ class Message(BaseModel): """Model for message data.""" model_config = ConfigDict(extra='forbid', populate_by_name=True) - role: Role + role: Role | str content: list[Part] metadata: dict[str, Any] | None = None diff --git a/py/pyproject.toml b/py/pyproject.toml index 5fadf9d6ba..5ee43b379b 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -199,9 +199,9 @@ convention = "google" docstring-code-format = true docstring-code-line-length = 120 indent-style = "space" +line-ending = "lf" quote-style = "single" skip-magic-trailing-comma = false -line-ending = "lf" [tool.datamodel-codegen] #collapse-root-models = true # Don't use; produces Any as types. diff --git a/py/samples/evaluator-demo/pyproject.toml b/py/samples/evaluator-demo/pyproject.toml index eb049714a7..771c966747 100644 --- a/py/samples/evaluator-demo/pyproject.toml +++ b/py/samples/evaluator-demo/pyproject.toml @@ -15,7 +15,7 @@ # SPDX-License-Identifier: Apache-2.0 [project] -authors = [{ name = "Google" }] +authors = [{ name = "Google" }] dependencies = ["genkit", "pydantic>=2.0.0", "structlog>=24.0.0", "pypdf"] description = "Genkit Python Evaluation Demo" name = "eval-demo" diff --git a/py/samples/multi-server/pyproject.toml b/py/samples/multi-server/pyproject.toml index 364e053b2e..7163fd5d6d 100644 --- a/py/samples/multi-server/pyproject.toml +++ b/py/samples/multi-server/pyproject.toml @@ -16,7 +16,6 @@ [project] authors = [{ name = "Google" }] -license = "Apache-2.0" classifiers = [ "Development Status :: 3 - Alpha", "Environment :: Console", @@ -43,6 +42,7 @@ dependencies = [ "uvicorn>=0.34.0", ] description = "Sample implementation to exercise the Genkit multi server manager." +license = "Apache-2.0" name = "multi-server" readme = "README.md" requires-python = ">=3.10"