Skip to content

Commit f182bda

Browse files
authored
Merge pull request #95 from jymchng/fix-#94-fastapi-compat
Fix #94 fastapi compat
2 parents 7277637 + a5b3f44 commit f182bda

File tree

9 files changed

+251
-34
lines changed

9 files changed

+251
-34
lines changed

.github/workflows/tests.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,14 @@ jobs:
100100
run: |
101101
uv run python -m pytest tests/ -vv -s
102102
103+
104+
- name: Run fastapi compatibility tests
105+
if: steps.check_test_files.outputs.files_exists == 'true'
106+
env:
107+
PYTHONWARNINGS: ignore
108+
run: |
109+
uv run nox -s test-compat-fastapi
110+
103111
#----------------------------------------------
104112
# make sure docs build
105113
#----------------------------------------------

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ https://github.com/jymchng/fastapi-shield
2121
<hr style="border: none; border-top: 1px solid #ccc; margin: 1em 0;">
2222

2323
### Compatibility and Version
24+
<img src="https://github.com/jymchng/fastapi-shield/actions/workflows/tests.yaml/badge.svg">
25+
<img src="https://img.shields.io/badge/dynamic/toml?url=https%3A%2F%2Fraw.githubusercontent.com%2Fjymchng%2Ffastapi-shield%2Frefs%2Fheads%2Fmain%2Fpyproject.toml&query=%24.project.dependencies%5B0%5D&label=compat&labelColor=green">
2426
<img src="https://img.shields.io/pypi/pyversions/fastapi-shield?color=green" alt="Python compat">
2527
<a href="https://pypi.python.org/pypi/fastapi-shield"><img src="https://img.shields.io/pypi/v/fastapi-shield.svg" alt="PyPi"></a>
2628

noxfile.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import shutil
44
from functools import wraps
55
import pathlib
6+
import urllib.request
7+
import json
8+
import re
69

710
import nox
811
import nox.command as nox_command
@@ -154,6 +157,126 @@ def session(
154157
)
155158

156159

160+
# --- FastAPI compatibility matrix helpers ---
161+
PYPI_JSON_URL_TEMPLATE = "https://pypi.org/pypi/{package}/json"
162+
163+
164+
def _parse_strict_version_tuple(ver_str: str):
165+
"""Parse a strict semantic version 'X.Y.Z' into a tuple of ints.
166+
167+
Returns None if version doesn't match strict pattern (filters out pre-releases).
168+
"""
169+
m = re.match(r"^(\d+)\.(\d+)\.(\d+)$", ver_str)
170+
if not m:
171+
return None
172+
return int(m.group(1)), int(m.group(2)), int(m.group(3))
173+
174+
175+
def _version_tuple_to_str(t):
176+
return f"{t[0]}.{t[1]}.{t[2]}"
177+
178+
179+
def _cmp_major_minor(a, b):
180+
"""Compare (major, minor) tuples only."""
181+
if a[0] != b[0]:
182+
return a[0] - b[0]
183+
return a[1] - b[1]
184+
185+
186+
def _get_min_supported_version_from_pyproject(
187+
package_name: str, manifest: dict = PROJECT_MANIFEST
188+
):
189+
"""Extract minimum supported version from pyproject for given package.
190+
191+
Supports entries like 'fastapi>=0.100.1' and 'fastapi[standard]>=0.100.1'.
192+
Returns a version tuple (major, minor, patch) or None if not found.
193+
"""
194+
deps = manifest.get("project", {}).get("dependencies", [])
195+
patterns = [
196+
rf"^{re.escape(package_name)}>=([0-9]+\.[0-9]+\.[0-9]+)$",
197+
rf"^{re.escape(package_name)}\[[^\]]+\]>=([0-9]+\.[0-9]+\.[0-9]+)$",
198+
]
199+
for dep in deps:
200+
for pat in patterns:
201+
m = re.match(pat, dep)
202+
if m:
203+
vt = _parse_strict_version_tuple(m.group(1))
204+
if vt:
205+
return vt
206+
return None
207+
208+
209+
def _fetch_pypi_latest_and_releases(package_name: str):
210+
"""Fetch latest version and releases list from PyPI JSON.
211+
212+
Returns (latest_version_tuple, releases_dict) where releases_dict maps
213+
(major, minor) -> max patch available for that minor.
214+
"""
215+
url = PYPI_JSON_URL_TEMPLATE.format(package=package_name)
216+
try:
217+
with urllib.request.urlopen(url) as resp:
218+
data = json.loads(resp.read().decode("utf-8"))
219+
except Exception:
220+
return None, {}
221+
222+
latest_str = data.get("info", {}).get("version")
223+
latest_tuple = _parse_strict_version_tuple(latest_str) if latest_str else None
224+
225+
releases = data.get("releases", {})
226+
minor_to_max_patch = {}
227+
for ver_str in releases.keys():
228+
vt = _parse_strict_version_tuple(ver_str)
229+
if not vt:
230+
# skip pre-release or non-strict versions
231+
continue
232+
major, minor, patch = vt
233+
key = (major, minor)
234+
prev = minor_to_max_patch.get(key)
235+
if prev is None or patch > prev:
236+
minor_to_max_patch[key] = patch
237+
238+
return latest_tuple, minor_to_max_patch
239+
240+
241+
def _build_minor_matrix(min_vt, latest_vt, minor_to_max_patch):
242+
"""Build a list of version strings representing the highest patch in each minor
243+
from min_vt to latest_vt inclusive. Only includes minors that exist in releases.
244+
"""
245+
if not min_vt or not latest_vt:
246+
return []
247+
result = []
248+
# Collect and sort available minor keys
249+
available_minors = sorted(minor_to_max_patch.keys(), key=lambda k: (k[0], k[1]))
250+
for major, minor in available_minors:
251+
# range filter: min <= (major, minor) <= latest
252+
if _cmp_major_minor((major, minor), (min_vt[0], min_vt[1])) < 0:
253+
continue
254+
if _cmp_major_minor((major, minor), (latest_vt[0], latest_vt[1])) > 0:
255+
continue
256+
patch = minor_to_max_patch[(major, minor)]
257+
result.append(_version_tuple_to_str((major, minor, patch)))
258+
return result
259+
260+
261+
def _compute_fastapi_minor_matrix():
262+
package = "fastapi"
263+
min_vt = _get_min_supported_version_from_pyproject(package)
264+
latest_vt, minor_to_max_patch = _fetch_pypi_latest_and_releases(package)
265+
matrix = _build_minor_matrix(min_vt, latest_vt, minor_to_max_patch)
266+
# Fallbacks if network fails or parsing issues
267+
if not matrix:
268+
vals = []
269+
if min_vt:
270+
vals.append(_version_tuple_to_str(min_vt))
271+
if latest_vt and latest_vt != min_vt:
272+
vals.append(_version_tuple_to_str(latest_vt))
273+
matrix = vals or ["0.100.1"]
274+
return matrix
275+
276+
277+
FASTAPI_MINOR_MATRIX = _compute_fastapi_minor_matrix()
278+
279+
157280
def uv_install_group_dependencies(session: Session, dependency_group: str):
158281
pyproject = nox.project.load_toml(MANIFEST_FILENAME)
159282
dependencies = nox.project.dependency_groups(pyproject, dependency_group)
@@ -256,6 +379,42 @@ def test(session: AlteredSession):
256379
session.run(*command)
257380

258381

382+
@session(
383+
dependency_group=None,
384+
default_posargs=[TEST_DIR, "-s", "-vv", "-n", "auto", "--dist", "worksteal"],
385+
reuse_venv=False,
386+
)
387+
@nox.parametrize("fastapi_version", FASTAPI_MINOR_MATRIX)
388+
def test_compat_fastapi(session: AlteredSession, fastapi_version: str):
389+
"""Run tests against a matrix of FastAPI minor versions.
390+
391+
The matrix is computed from pyproject's minimum supported version and
392+
PyPI's latest release, selecting the highest patch per minor.
393+
"""
394+
session.log(f"Testing compatibility with FastAPI versions: {FASTAPI_MINOR_MATRIX}")
395+
# Pin FastAPI (and extras) to the target minor's highest patch before running tests.
396+
# Install dev dependencies excluding FastAPI to avoid overriding the pinned version.
397+
pyproject = load_toml(MANIFEST_FILENAME)
398+
dev_deps = nox.project.dependency_groups(pyproject, "dev")
399+
filtered_dev_deps = [d for d in dev_deps if not d.startswith("fastapi")]
400+
if filtered_dev_deps:
401+
session.install(*filtered_dev_deps)
402+
# Pin FastAPI (and extras) to the target minor's highest patch before running tests.
403+
session.install(f"fastapi[standard]=={fastapi_version}")
404+
with alter_session(session, dependency_group=None) as session:
405+
session.install(f".")
406+
session.run(
407+
*(
408+
"python",
409+
"-c",
410+
f'from fastapi import __version__; assert __version__ == "{fastapi_version}", __version__',
411+
)
412+
)
413+
414+
# Run pytest using the Nox-managed virtualenv (avoid external interpreter).
415+
session.run("pytest")
416+
417+
259418
@contextlib.contextmanager
260419
def alter_session(
261420
session: AlteredSession,
@@ -606,6 +765,42 @@ def ci(session: Session):
606765
test(session)
607766

608767

768+
@session(reuse_venv=False)
769+
def install_latest_tarball(session: Session):
770+
import glob
771+
import re
772+
773+
from packaging import version
774+
775+
# Get all tarball files
776+
tarball_files = glob.glob(f"{DIST_DIR}/{PROJECT_NAME_NORMALIZED}-*.tar.gz")
777+
778+
if not tarball_files:
779+
session.error("No tarball files found in dist/ directory")
780+
781+
# Extract version numbers using regex
782+
version_pattern = re.compile(
783+
rf"{PROJECT_NAME_NORMALIZED}-([0-9]+\.[0-9]+\.[0-9]+(?:\.[0-9]+)?(?:(?:a|b|rc)[0-9]+)?(?:\.post[0-9]+)?(?:\.dev[0-9]+)?).tar.gz"
784+
)
785+
786+
# Create a list of (file_path, version) tuples
787+
versioned_files = []
788+
for file_path in tarball_files:
789+
match = version_pattern.search(file_path)
790+
if match:
791+
ver_str = match.group(1)
792+
versioned_files.append((file_path, version.parse(ver_str)))
793+
794+
if not versioned_files:
795+
session.error("Could not extract version information from tarball files")
796+
797+
# Sort by version (highest first) and get the path
798+
latest_tarball = sorted(versioned_files, key=lambda x: x[1], reverse=True)[0][0]
799+
session.log(f"Installing latest version: {latest_tarball}")
800+
session.run("uv", "run", "pip", "uninstall", f"{PROJECT_NAME}", "-y")
801+
session.install(latest_tarball)
802+
803+
609804
@session(reuse_venv=False)
610805
def test_client_install_run(session: Session):
611806
with alter_session(session, dependency_group="dev"):

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ classifiers = [
4545
]
4646
requires-python = ">=3.9"
4747
dependencies = [
48-
"fastapi>=0.100.1",
48+
"fastapi>=0.115.2",
4949
"typing-extensions>=4.0.0; python_version<'3.10'",
5050
]
5151
dynamic = []
@@ -116,7 +116,8 @@ reportMissingImports = "none"
116116
[dependency-groups]
117117
dev = [
118118
"bcrypt==4.3.0",
119-
"fastapi[standard]>=0.100.1",
119+
"email-validator>=2.3.0",
120+
"fastapi>=0.115.2",
120121
"httpx>=0.24.0",
121122
"isort>=6.0.1",
122123
"mypy>=1.18.2",

src/fastapi_shield/shield.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@
3737

3838
from fastapi import HTTPException, Request, Response, status
3939
from fastapi._compat import _normalize_errors
40-
from fastapi.dependencies.utils import is_coroutine_callable
4140
from fastapi.exceptions import RequestValidationError
4241
from fastapi.params import Security
4342
from typing_extensions import Doc
4443

4544
# Import directly to make patching work correctly in tests
4645
import fastapi_shield.utils
46+
from fastapi_shield.utils import is_coroutine_callable
4747
from fastapi_shield.consts import (
4848
IS_SHIELDED_ENDPOINT_KEY,
4949
SHIELDED_ENDPOINT_PATH_FORMAT_KEY,

src/fastapi_shield/utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,58 @@
1111
from collections.abc import Iterator
1212
from contextlib import AsyncExitStack
1313
from inspect import Parameter, signature
14+
import inspect
1415
from typing import Any, Callable, Optional, List, Union
1516

1617
from fastapi import HTTPException, Request, params
1718
from fastapi._compat import ModelField, Undefined
1819
from fastapi.dependencies.models import Dependant
1920
from fastapi.dependencies.utils import (
20-
_should_embed_body_fields,
2121
get_body_field,
2222
get_dependant,
2323
get_flat_dependant,
2424
solve_dependencies,
2525
)
26+
from pydantic import BaseModel
27+
from pydantic._internal._utils import lenient_issubclass
2628
from fastapi.exceptions import RequestValidationError
2729

2830
from starlette.routing import get_name
2931

3032

33+
# copied from `fastapi.dependencies.utils`
34+
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
35+
if inspect.isroutine(call):
36+
return inspect.iscoroutinefunction(call)
37+
if inspect.isclass(call):
38+
return False
39+
dunder_call = getattr(call, "__call__", None) # noqa: B004
40+
return inspect.iscoroutinefunction(dunder_call)
41+
42+
43+
# copied from `fastapi.dependencies.utils`
44+
def _should_embed_body_fields(fields: List["ModelField"]) -> bool:
45+
if not fields:
46+
return False
47+
# More than one dependency could have the same field, it would show up as multiple
48+
# fields but it's the same one, so count them by name
49+
body_param_names_set = {field.name for field in fields}
50+
# A top level field has to be a single field, not multiple
51+
if len(body_param_names_set) > 1:
52+
return True
53+
first_field = fields[0]
54+
# If it explicitly specifies it is embedded, it has to be embedded
55+
if getattr(first_field.field_info, "embed", None):
56+
return True
57+
# If it's a Form (or File) field, it has to be a BaseModel to be top level
58+
# otherwise it has to be embedded, so that the key value pair can be extracted
59+
if isinstance(first_field.field_info, params.Form) and not lenient_issubclass(
60+
first_field.type_, BaseModel
61+
):
62+
return True
63+
return False
64+
65+
3166
def generate_unique_id_for_fastapi_shield(dependant: Dependant, path_format: str):
3267
"""Generate a unique identifier for FastAPI Shield dependants.
3368

tests/test_basics.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,16 +218,8 @@ def test_unprotected_endpoint():
218218
client = TestClient(app)
219219
response = client.get("/unprotected")
220220
assert response.status_code == 200
221-
assert response.json() == {
222-
"message": "This is an unprotected endpoint",
223-
"user": {
224-
"dependency": {},
225-
"use_cache": True,
226-
"scopes": [],
227-
"shielded_dependency": {},
228-
"unblocked": False,
229-
},
230-
}, response.json()
221+
result_json = response.json()
222+
assert result_json["message"] == "This is an unprotected endpoint", response.json()
231223

232224

233225
def test_protected_endpoint_without_token():

tests/test_basics_three.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -300,16 +300,8 @@ def test_unprotected_endpoint():
300300
client = TestClient(app)
301301
response = client.get("/unprotected")
302302
assert response.status_code == 200
303-
assert response.json() == {
304-
"message": "This is an unprotected endpoint",
305-
"user": {
306-
"dependency": {},
307-
"use_cache": True,
308-
"scopes": [],
309-
"shielded_dependency": {},
310-
"unblocked": False,
311-
},
312-
}, response.json()
303+
result_json = response.json()
304+
assert result_json["message"] == "This is an unprotected endpoint", response.json()
313305

314306

315307
def test_protected_endpoint_without_token():

tests/test_basics_two.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -300,16 +300,8 @@ def test_unprotected_endpoint():
300300
client = TestClient(app)
301301
response = client.get("/unprotected")
302302
assert response.status_code == 200
303-
assert response.json() == {
304-
"message": "This is an unprotected endpoint",
305-
"user": {
306-
"dependency": {},
307-
"use_cache": True,
308-
"scopes": [],
309-
"shielded_dependency": {},
310-
"unblocked": False,
311-
},
312-
}, response.json()
303+
result_json = response.json()
304+
assert result_json["message"] == "This is an unprotected endpoint", response.json()
313305

314306

315307
def test_protected_endpoint_without_token():

0 commit comments

Comments
 (0)