Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
runs-on: "${{ matrix.os }}"
strategy:
matrix:
python-version: [3.7, 3.8, 3.9, "3.10"]
python-version: [3.9, 3.10, 3.11, 3.12]
os: [windows-latest, ubuntu-latest, macos-latest]
env:
OS: ${{ matrix.os }}
Expand Down Expand Up @@ -102,3 +102,4 @@ jobs:

- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

7 changes: 4 additions & 3 deletions rpcpy/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from rpcpy.openapi import (
ValidationError,
create_model,
create_root_model,
is_typed_dict_type,
parse_typed_dict,
set_type_model,
Expand Down Expand Up @@ -162,9 +163,9 @@ def get_openapi_docs(self) -> dict:
elif return_annotation is None:
resp_model = create_model(callback.__name__ + "-return")
else:
resp_model = create_model(
callback.__name__ + "-return",
__root__=(return_annotation, ...),
resp_model = create_root_model(
model_name=callback.__name__ + "-return",
return_annotation=return_annotation,
)
_schema = copy.deepcopy(resp_model.schema())
definitions.update(_schema.pop("definitions", {}))
Expand Down
2 changes: 1 addition & 1 deletion rpcpy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def __init__(self) -> None:
self.message: ServerSentEvent = {}

def feed(self, line: str) -> ServerSentEvent | None:
if line == "\n": # event split line
if not line or line == "\n": # event split line
event = self.message
self.message = {}
return event
Expand Down
27 changes: 27 additions & 0 deletions rpcpy/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
Callable = typing.TypeVar("Callable", bound=typing.Callable)

try:
from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel, ValidationError, create_model
from pydantic import validate_arguments as pydantic_validate_arguments

IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2

# visit this issue
# https://github.com/samuelcolvin/pydantic/issues/1205
def validate_arguments(function: Callable) -> Callable:
Expand All @@ -41,6 +44,8 @@ def change_exception(*args, **kwargs):

except ImportError:

IS_PYDANTIC_V2 = False

def create_model(*args, **kwargs): # type: ignore
raise NotImplementedError("Need install `pydantic` from pypi.")

Expand All @@ -53,8 +58,12 @@ class ValidationError(Exception): # type: ignore
"""

if typing.TYPE_CHECKING:
from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel

if IS_PYDANTIC_V2:
from pydantic import RootModel


def set_type_model(func: Callable) -> Callable:
"""
Expand Down Expand Up @@ -111,6 +120,24 @@ def parse_typed_dict(typed_dict) -> typing.Type[BaseModel]:
return create_model(typed_dict.__name__, **annotations) # type: ignore


def create_root_model(model_name: str, return_annotation: type) -> typing.Type[BaseModel]:
"""
Create a Pydantic model with a single root field for the return type.

This function handles both Pydantic v1 and v2 styles of model creation.
"""
if IS_PYDANTIC_V2:
# Dynamically create a subclass of RootModel
return type(
model_name,
(RootModel,),
{"__annotations__": {"root": return_annotation}},
)
else:
# Pydantic v1 style using create_model with __root__
return create_model(model_name, __root__=(return_annotation, ...))


TEMPLATE = """<!DOCTYPE html>
<html>

Expand Down
2 changes: 0 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def sayhi(name: str) -> str:
@app.register
def yield_data(max_num: int) -> Generator[int, None, None]:
for i in range(max_num):
time.sleep(1)
yield i

@app.register
Expand Down Expand Up @@ -55,7 +54,6 @@ async def sayhi(name: str) -> str:
@app.register
async def yield_data(max_num: int) -> AsyncGenerator[int, None]:
for i in range(max_num):
await asyncio.sleep(1)
yield i

@app.register
Expand Down
Loading