Skip to content

Commit 0e7adb3

Browse files
authored
Feature/bind by type handling generic (#110)
1 parent e88c8f9 commit 0e7adb3

File tree

4 files changed

+69
-9
lines changed

4 files changed

+69
-9
lines changed

.github/workflows/workflow.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242
os: ubuntu-latest
4343
- python: "3.10"
4444
os: ubuntu-latest
45-
- python: "3.11.0-beta.1 - 3.11"
45+
- python: "3.11"
4646
os: ubuntu-latest
4747
# test OSs
4848
- python: "3.x"

di/_container.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,15 @@ def bind_by_type(
8181
def hook(
8282
param: inspect.Parameter | None, dependent: DependentBase[Any]
8383
) -> DependentBase[Any] | None:
84-
if dependent.call is dependency:
84+
if dependent.call == dependency:
8585
return provider
8686
if param is None:
8787
return None
8888
type_annotation_option = get_type(param)
8989
if type_annotation_option is None:
9090
return None
9191
type_annotation = type_annotation_option.value
92-
if type_annotation is dependency:
92+
if type_annotation == dependency:
9393
return provider
9494
if covariant:
9595
if inspect.isclass(type_annotation) and inspect.isclass(dependency):

pyproject.toml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "di"
3-
version = "0.77.0"
3+
version = "0.78.0"
44
description = "Dependency injection toolkit"
55
authors = ["Adrian Garcia Badaracco <adrian@adriangb.com>"]
66
readme = "README.md"
@@ -38,6 +38,7 @@ anyio = ["anyio"]
3838
# linting
3939
black = "~23"
4040
mypy = "~1"
41+
ruff = "^0.0.286"
4142
pre-commit = "~2"
4243
# testing
4344
pytest = "~7"
@@ -48,18 +49,15 @@ coverage = { extras = ["toml"], version = "~6" }
4849
# docs
4950
mkdocs = "~1"
5051
mkdocs-material = "~8,!=8.1.3"
52+
mkdocstrings = {version = "^0.19.0", extras = ["python"]}
5153
mike = "~1"
5254
# benchmarking
5355
pyinstrument = "~4"
54-
mkdocstrings = {version = "^0.19.0", extras = ["python"]}
55-
ruff = "^0.0.286"
5656

5757
[build-system]
5858
requires = ["poetry-core"]
5959
build-backend = "poetry.core.masonry.api"
6060

61-
[tool.isort]
62-
profile = "black"
6361

6462
[tool.coverage.run]
6563
branch = true

tests/test_binding.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import List
1+
import sys
2+
from abc import abstractmethod
3+
from typing import List, TypeVar
24

35
import pytest
46

@@ -7,6 +9,11 @@
79
from di.executors import SyncExecutor
810
from di.typing import Annotated
911

12+
if sys.version_info < (3, 8): # pragma: no cover
13+
from typing_extensions import Protocol
14+
else: # pragma: no cover
15+
from typing import Protocol
16+
1017

1118
class Request:
1219
def __init__(self, value: int = 0) -> None:
@@ -47,6 +54,61 @@ def __init__(self, v: int = 1) -> None:
4754
assert res.v == 1
4855

4956

57+
T_co = TypeVar("T_co", covariant=True)
58+
59+
60+
def test_bind_generic():
61+
container = Container()
62+
executor = SyncExecutor()
63+
expected = 100
64+
65+
class GetterInterface(Protocol[T_co]):
66+
@abstractmethod
67+
def get(self) -> T_co:
68+
...
69+
70+
class GetterIntImpl(GetterInterface[int]):
71+
def __init__(self, v: int) -> None:
72+
self.v = v
73+
74+
def get(self) -> int:
75+
return self.v
76+
77+
def factory() -> GetterIntImpl:
78+
return GetterIntImpl(expected)
79+
80+
hook = bind_by_type(
81+
Dependent(factory),
82+
GetterInterface[int],
83+
)
84+
container.bind(hook)
85+
86+
# ===========================================
87+
# clean `_tp_cache`
88+
from typing import _cleanups as cache_cleanups # type: ignore[attr-defined]
89+
90+
for cache_cleanup in cache_cleanups:
91+
cache_cleanup()
92+
# ===========================================
93+
94+
class IntService:
95+
"""Declared after binding and cache clearing."""
96+
97+
def __init__(self, getter: GetterInterface[int]) -> None:
98+
self.getter = getter
99+
100+
scopes = [None]
101+
flat_dependent = Dependent(GetterInterface[int])
102+
wired_dependent = Dependent(IntService)
103+
with container.enter_scope(None) as state:
104+
flat_solved = container.solve(flat_dependent, scopes)
105+
wired_solved = container.solve(wired_dependent, scopes)
106+
flat = flat_solved.execute_sync(executor, state)
107+
wired = wired_solved.execute_sync(executor, state)
108+
109+
assert flat.get() == wired.getter.get() == expected
110+
111+
50112
def test_bind_transitive_dependency_results_skips_subdpendencies():
51113
"""If we bind a transitive dependency none of it's sub-dependencies should be executed
52114
since they are no longer required.

0 commit comments

Comments
 (0)