Skip to content

Commit c840057

Browse files
committed
fix(py/genkit): fix type errors reported by ty in typing sanitizer
CHANGELOG: - [ ] Fix errors in py/bin/sanitize_schema_typing.py and typing.py - [ ] Format some pyproject.toml files
1 parent 7b91230 commit c840057

File tree

5 files changed

+39
-16
lines changed

5 files changed

+39
-16
lines changed

py/bin/sanitize_schema_typing.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from _ast import AST
4646
from datetime import datetime
4747
from pathlib import Path
48-
from typing import Type, cast
48+
from typing import Any, Type, cast
4949

5050

5151
class ClassTransformer(ast.NodeTransformer):
@@ -118,21 +118,36 @@ def has_model_config(self, node: ast.ClassDef) -> ast.Assign | None:
118118
return item
119119
return None
120120

121-
def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
122-
"""Visit and transform a class definition node.
123-
124-
Args:
125-
node: The ClassDef AST node to transform.
121+
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign:
122+
"""Visit and transform annotated assignment."""
123+
if isinstance(node.annotation, ast.Name) and node.annotation.id == 'Role':
124+
node.annotation = ast.BinOp(
125+
left=ast.Name(id='Role', ctx=ast.Load()),
126+
op=ast.BitOr(),
127+
right=ast.Name(id='str', ctx=ast.Load()),
128+
)
129+
self.modified = True
130+
return node
126131

127-
Returns:
128-
The transformed ClassDef node.
129-
"""
132+
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
133+
# Visit and transform a class definition node.
134+
#
135+
# Args:
136+
# node: The ClassDef AST node to transform.
137+
#
138+
# Returns:
139+
# The transformed ClassDef node.
130140
# First apply base class transformations recursively
131-
node = super().generic_visit(_node)
141+
node = cast(ast.ClassDef, super().generic_visit(node))
132142
new_body: list[ast.stmt | ast.Constant | ast.Assign] = []
133143

134144
# Handle Docstrings
135-
if not node.body or not isinstance(node.body[0], ast.Expr) or not isinstance(node.body[0].value, ast.Constant):
145+
if (
146+
not node.body
147+
or not isinstance(node.body[0], ast.Expr)
148+
or not isinstance(node.body[0].value, ast.Constant)
149+
or not isinstance(node.body[0].value.value, str)
150+
):
136151
# Generate a more descriptive docstring based on class type
137152
if self.is_rootmodel_class(node):
138153
docstring = f'Root model for {node.name.lower().replace("_", " ")}.'
@@ -151,13 +166,21 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
151166

152167
# Handle model_config for BaseModel and RootModel
153168
existing_model_config_assign = self.has_model_config(node)
169+
154170
existing_model_config_call = None
155171
if existing_model_config_assign and isinstance(existing_model_config_assign.value, ast.Call):
156172
existing_model_config_call = existing_model_config_assign.value
157173

158174
# Determine start index for iterating original body (skip docstring)
159175
body_start_index = (
160-
1 if (node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Str)) else 0
176+
1
177+
if (
178+
node.body
179+
and isinstance(node.body[0], ast.Expr)
180+
and isinstance(node.body[0].value, ast.Constant)
181+
and isinstance(node.body[0].value.value, str)
182+
)
183+
else 0
161184
)
162185

163186
if self.is_rootmodel_class(node):

py/packages/genkit/src/genkit/core/typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,7 @@ class Message(BaseModel):
913913
"""Model for message data."""
914914

915915
model_config = ConfigDict(extra='forbid', populate_by_name=True)
916-
role: Role
916+
role: Role | str
917917
content: list[Part]
918918
metadata: dict[str, Any] | None = None
919919

py/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ convention = "google"
199199
docstring-code-format = true
200200
docstring-code-line-length = 120
201201
indent-style = "space"
202+
line-ending = "lf"
202203
quote-style = "single"
203204
skip-magic-trailing-comma = false
204-
line-ending = "lf"
205205

206206
[tool.datamodel-codegen]
207207
#collapse-root-models = true # Don't use; produces Any as types.

py/samples/evaluator-demo/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# SPDX-License-Identifier: Apache-2.0
1616

1717
[project]
18-
authors = [{ name = "Google" }]
18+
authors = [{ name = "Google" }]
1919
dependencies = ["genkit", "pydantic>=2.0.0", "structlog>=24.0.0", "pypdf"]
2020
description = "Genkit Python Evaluation Demo"
2121
name = "eval-demo"

py/samples/multi-server/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
[project]
1818
authors = [{ name = "Google" }]
19-
license = "Apache-2.0"
2019
classifiers = [
2120
"Development Status :: 3 - Alpha",
2221
"Environment :: Console",
@@ -43,6 +42,7 @@ dependencies = [
4342
"uvicorn>=0.34.0",
4443
]
4544
description = "Sample implementation to exercise the Genkit multi server manager."
45+
license = "Apache-2.0"
4646
name = "multi-server"
4747
readme = "README.md"
4848
requires-python = ">=3.10"

0 commit comments

Comments
 (0)