Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
runs-on: "${{ matrix.os }}"
strategy:
matrix:
python-version: [3.7, 3.8, 3.9, "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: [windows-latest, ubuntu-latest, macos-latest]
env:
OS: ${{ matrix.os }}
Expand Down Expand Up @@ -81,7 +81,7 @@ jobs:
id-token: write
strategy:
matrix:
python-version: [3.7]
python-version: ["3.12"]
os: [ubuntu-latest]

steps:
Expand All @@ -102,3 +102,4 @@ jobs:

- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

7 changes: 4 additions & 3 deletions rpcpy/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from rpcpy.openapi import (
ValidationError,
create_model,
create_root_model,
is_typed_dict_type,
parse_typed_dict,
set_type_model,
Expand Down Expand Up @@ -162,9 +163,9 @@ def get_openapi_docs(self) -> dict:
elif return_annotation is None:
resp_model = create_model(callback.__name__ + "-return")
else:
resp_model = create_model(
callback.__name__ + "-return",
__root__=(return_annotation, ...),
resp_model = create_root_model(
model_name=callback.__name__ + "-return",
return_annotation=return_annotation,
)
_schema = copy.deepcopy(resp_model.schema())
definitions.update(_schema.pop("definitions", {}))
Expand Down
2 changes: 1 addition & 1 deletion rpcpy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def __init__(self) -> None:
self.message: ServerSentEvent = {}

def feed(self, line: str) -> ServerSentEvent | None:
if line == "\n": # event split line
if not line or line == "\n": # event split line
event = self.message
self.message = {}
return event
Expand Down
27 changes: 27 additions & 0 deletions rpcpy/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
Callable = typing.TypeVar("Callable", bound=typing.Callable)

try:
from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel, ValidationError, create_model
from pydantic import validate_arguments as pydantic_validate_arguments

IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2

# visit this issue
# https://github.com/samuelcolvin/pydantic/issues/1205
def validate_arguments(function: Callable) -> Callable:
Expand All @@ -41,6 +44,8 @@ def change_exception(*args, **kwargs):

except ImportError:

IS_PYDANTIC_V2 = False

def create_model(*args, **kwargs): # type: ignore
raise NotImplementedError("Need install `pydantic` from pypi.")

Expand All @@ -53,8 +58,12 @@ class ValidationError(Exception): # type: ignore
"""

if typing.TYPE_CHECKING:
from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel

if IS_PYDANTIC_V2:
from pydantic import RootModel


def set_type_model(func: Callable) -> Callable:
"""
Expand Down Expand Up @@ -111,6 +120,24 @@ def parse_typed_dict(typed_dict) -> typing.Type[BaseModel]:
return create_model(typed_dict.__name__, **annotations) # type: ignore


def create_root_model(model_name: str, return_annotation: type) -> typing.Type[BaseModel]:
"""
Create a Pydantic model with a single root field for the return type.

This function handles both Pydantic v1 and v2 styles of model creation.
"""
if IS_PYDANTIC_V2:
# Dynamically create a subclass of RootModel
return type(
model_name,
(RootModel,),
{"__annotations__": {"root": return_annotation}},
)
else:
# Pydantic v1 style using create_model with __root__
return create_model(model_name, __root__=(return_annotation, ...))


TEMPLATE = """<!DOCTYPE html>
<html>

Expand Down
3 changes: 0 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import time
from typing import AsyncGenerator, Generator

import httpx
Expand All @@ -25,7 +24,6 @@ def sayhi(name: str) -> str:
@app.register
def yield_data(max_num: int) -> Generator[int, None, None]:
for i in range(max_num):
time.sleep(1)
yield i

@app.register
Expand Down Expand Up @@ -55,7 +53,6 @@ async def sayhi(name: str) -> str:
@app.register
async def yield_data(max_num: int) -> AsyncGenerator[int, None]:
for i in range(max_num):
await asyncio.sleep(1)
yield i

@app.register
Expand Down
Loading