|
11 | 11 | from operator import methodcaller |
12 | 12 | from typing_extensions import Annotated |
13 | 13 | from dataclasses import field, dataclass |
14 | | -from typing import Any, TypeVar, Coroutine |
15 | 14 | from inspect import Parameter, Signature, isclass |
16 | 15 | from collections.abc import Callable, Iterable, Generator |
| 16 | +from typing import TYPE_CHECKING, Any, TypeVar, Coroutine |
17 | 17 | from importlib.metadata import Distribution, PackageNotFoundError, distribution |
18 | 18 |
|
19 | 19 | import click |
|
36 | 36 | from typing_extensions import ParamSpec, get_args, get_origin |
37 | 37 |
|
38 | 38 |
|
| 39 | +if TYPE_CHECKING: |
| 40 | + from . import async_scoped_session |
| 41 | + |
| 42 | + |
39 | 43 | _T = TypeVar("_T") |
40 | 44 | _P = ParamSpec("_P") |
41 | 45 |
|
@@ -73,60 +77,67 @@ def write(self, buffer: str): |
73 | 77 | while frame and frame.f_code.co_name != "print_stdout": |
74 | 78 | frame = frame.f_back |
75 | 79 | depth += 1 |
76 | | - depth += 1 |
77 | 80 |
|
78 | 81 | for line in buffer.rstrip().splitlines(): |
79 | | - logger.opt(depth=depth).log(self._level, line.rstrip()) |
| 82 | + logger.opt(depth=depth + 1).log(self._level, line.rstrip()) |
80 | 83 |
|
81 | 84 | def flush(self): |
82 | 85 | pass |
83 | 86 |
|
84 | 87 |
|
85 | | -@dataclass |
| 88 | +@dataclass(unsafe_hash=True) |
86 | 89 | class Option: |
87 | 90 | stream: bool = True |
88 | 91 | scalars: bool = False |
89 | 92 | result: methodcaller | None = None |
90 | | - calls: list[methodcaller] = field(default_factory=list) |
| 93 | + calls: tuple[methodcaller] = field(default_factory=tuple) |
91 | 94 |
|
92 | 95 |
|
93 | | -def compile_dependency(statement: ExecutableReturnsRows, option: Option) -> Any: |
94 | | - from . import async_scoped_session |
| 96 | +@dataclass |
| 97 | +class Dependency: |
| 98 | + __signature__: Signature = field(init=False) |
| 99 | + |
| 100 | + statement: ExecutableReturnsRows |
| 101 | + option: Option |
| 102 | + |
| 103 | + def __post_init__(self) -> None: |
| 104 | + from . import async_scoped_session |
| 105 | + |
| 106 | + self.__signature__ = Signature( |
| 107 | + [ |
| 108 | + Parameter( |
| 109 | + "_session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session |
| 110 | + ), |
| 111 | + *( |
| 112 | + Parameter(name, Parameter.KEYWORD_ONLY, default=depends) |
| 113 | + for name, depends in self.statement.compile().params.items() |
| 114 | + if isinstance(depends, DependsInner) |
| 115 | + ), |
| 116 | + ] |
| 117 | + ) |
95 | 118 |
|
96 | | - async def __dependency(*, __session: async_scoped_session, **params: Any): |
97 | | - if option.stream: |
98 | | - result = await __session.stream(statement, params) |
| 119 | + async def __call__(self, *, _session: async_scoped_session, **params: Any) -> Any: |
| 120 | + if self.option.stream: |
| 121 | + result = await _session.stream(self.statement, params) |
99 | 122 | else: |
100 | | - result = await __session.execute(statement, params) |
| 123 | + result = await _session.execute(self.statement, params) |
101 | 124 |
|
102 | | - for call in option.calls: |
| 125 | + for call in self.option.calls: |
103 | 126 | result = call(result) |
104 | 127 |
|
105 | | - if option.scalars: |
| 128 | + if self.option.scalars: |
106 | 129 | result = result.scalars() |
107 | 130 |
|
108 | | - if call := option.result: |
| 131 | + if call := self.option.result: |
109 | 132 | result = call(result) |
110 | 133 |
|
111 | | - if option.stream: |
| 134 | + if self.option.stream: |
112 | 135 | result = await result |
113 | 136 |
|
114 | 137 | return result |
115 | 138 |
|
116 | | - __dependency.__signature__ = Signature( |
117 | | - [ |
118 | | - Parameter( |
119 | | - "__session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session |
120 | | - ), |
121 | | - *( |
122 | | - Parameter(name, Parameter.KEYWORD_ONLY, default=depends) |
123 | | - for name, depends in statement.compile().params.items() |
124 | | - if isinstance(depends, DependsInner) |
125 | | - ), |
126 | | - ] |
127 | | - ) |
128 | | - |
129 | | - return Depends(__dependency) |
| 139 | + def __hash__(self) -> int: |
| 140 | + return hash((self.statement, self.option)) |
130 | 141 |
|
131 | 142 |
|
132 | 143 | def generic_issubclass(scls: Any, cls: Any) -> Any: |
|
0 commit comments