Skip to content

Commit 1cb1d66

Browse files
fixup! Add support for Fast Stream Depends
1 parent 575d0e4 commit 1cb1d66

File tree

1 file changed

+23
-32
lines changed

1 file changed

+23
-32
lines changed

src/dependency_injector/wiring.py

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,19 @@ def get_args(hint):
4747
def get_origin(tp):
4848
return None
4949

50+
MARKER_EXTRACTORS = []
5051

5152
try:
52-
import fastapi.params
53+
from fastapi.params import Depends as FastApiDepends
5354
except ImportError:
54-
fastapi = None
55+
pass
56+
else:
57+
def extract_marker_from_fastapi(param: Any) -> Any:
58+
if isinstance(param, FastApiDepends):
59+
return param.dependency
60+
return None
61+
62+
MARKER_EXTRACTORS.append(extract_marker_from_fastapi)
5563

5664

5765
try:
@@ -67,10 +75,16 @@ def get_origin(tp):
6775

6876

6977
try:
70-
import fast_depends.dependencies
78+
from fast_depends.dependencies import Depends as FastDepends
7179
except ImportError:
72-
fast_depends = None
80+
pass
81+
else:
82+
def extract_marker_from_fast_depends(param: Any) -> Any:
83+
if isinstance(param, FastDepends):
84+
return param.dependency
85+
return None
7386

87+
MARKER_EXTRACTORS.append(extract_marker_from_fast_depends)
7488

7589
from . import providers
7690

@@ -102,23 +116,6 @@ def get_origin(tp):
102116
else:
103117
Container = Any
104118

105-
def _is_fastapi_depends(param: Any) -> bool:
106-
return fastapi and isinstance(param, fastapi.params.Depends)
107-
108-
109-
if fast_depends:
110-
def _is_fast_stream_depends(param: Any) -> bool:
111-
return isinstance(param, fast_depends.dependencies.Depends)
112-
else:
113-
def _is_fast_stream_depends(param: Any) -> bool:
114-
return False
115-
116-
117-
_DEPENDS_CHECKERS = (
118-
_is_fastapi_depends,
119-
_is_fast_stream_depends,
120-
)
121-
122119

123120
class PatchedRegistry:
124121

@@ -606,8 +603,6 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None:
606603

607604

608605
def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
609-
depends_available = False
610-
611606
if get_origin(parameter.annotation) is Annotated:
612607
args = get_args(parameter.annotation)
613608
if len(args) > 1:
@@ -617,18 +612,14 @@ def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
617612
else:
618613
marker = parameter.default
619614

620-
if any(depends_checker(marker) for depends_checker in _DEPENDS_CHECKERS):
621-
depends_available = True
615+
for marker_extractor in MARKER_EXTRACTORS:
616+
if _marker := marker_extractor(marker):
617+
marker = _marker
618+
break
622619

623-
if not isinstance(marker, _Marker) and not depends_available:
620+
if not isinstance(marker, _Marker):
624621
return None
625622

626-
if depends_available:
627-
marker = marker.dependency
628-
629-
if not isinstance(marker, _Marker):
630-
return None
631-
632623
return marker
633624

634625

0 commit comments

Comments
 (0)