Skip to content

Commit cc73a0f

Browse files
committed
⚰️ refactor(sqla): remove utils.methodcaller
1 parent b732577 commit cc73a0f

File tree

3 files changed

+15
-34
lines changed

3 files changed

+15
-34
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _create_engine(engine: str | URL | dict[str, Any] | AsyncEngine) -> AsyncEng
166166
else:
167167
url = engine
168168

169-
return create_async_engine(make_url(url), **options)
169+
return create_async_engine(url, **options)
170170

171171

172172
def _init_engines():
@@ -204,7 +204,7 @@ def _init_table():
204204
_plugins = {}
205205

206206
_get_plugin_by_module_name = lru_cache(None)(get_plugin_by_module_name)
207-
for model in get_subclasses(Model):
207+
for model in set(get_subclasses(Model)):
208208
table: Table | None = getattr(model, "__table__", None)
209209

210210
if table is None or (bind_key := table.info.get("bind_key")) is None:
@@ -229,7 +229,7 @@ def _init_logger():
229229
"sqlalchemy": log_level,
230230
**{
231231
_qual_logger_name_for_cls(cls): echo_log_level
232-
for cls in get_subclasses(Identified)
232+
for cls in set(get_subclasses(Identified))
233233
},
234234
}
235235

nonebot_plugin_orm/param.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from itertools import repeat
55
from typing import Any, cast
66
from dataclasses import dataclass
7+
from operator import methodcaller
78
from inspect import Parameter, isclass
89

910
from pydantic.fields import FieldInfo
@@ -15,7 +16,7 @@
1516
from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult
1617

1718
from .model import Model
18-
from .utils import Option, methodcaller, compile_dependency, generic_issubclass
19+
from .utils import Option, compile_dependency, generic_issubclass
1920

2021
if sys.version_info >= (3, 9):
2122
from typing import Annotated

nonebot_plugin_orm/utils.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from itertools import repeat
1010
from contextlib import suppress
1111
from typing import Any, TypeVar
12+
from operator import methodcaller
1213
from typing_extensions import Annotated
1314
from dataclasses import field, dataclass
1415
from inspect import Parameter, Signature, isclass
@@ -95,29 +96,10 @@ class Option:
9596
calls: list[methodcaller] = field(default_factory=list)
9697

9798

98-
class methodcaller:
99-
__slots__ = ("_name", "_args", "_kwargs")
100-
101-
def __init__(self, name, /, *args, **kwargs):
102-
self._name = name
103-
if not isinstance(self._name, str):
104-
raise TypeError("method name must be a string")
105-
self._args = args
106-
self._kwargs = kwargs
107-
108-
def __call__(self, obj):
109-
return getattr(obj, self._name)(*self._args, **self._kwargs)
110-
111-
def __eq__(self, value: object, /) -> bool:
112-
return isinstance(value, methodcaller) and all(
113-
getattr(self, attr) == getattr(value, attr) for attr in self.__slots__
114-
)
115-
116-
11799
def compile_dependency(statement: ExecutableReturnsRows, option: Option) -> Any:
118100
from . import async_scoped_session
119101

120-
async def dependency(*, __session: async_scoped_session, **params: Any):
102+
async def __dependency(*, __session: async_scoped_session, **params: Any):
121103
if option.stream:
122104
result = await __session.stream(statement, params)
123105
else:
@@ -137,7 +119,7 @@ async def dependency(*, __session: async_scoped_session, **params: Any):
137119

138120
return result
139121

140-
dependency.__signature__ = Signature(
122+
__dependency.__signature__ = Signature(
141123
[
142124
Parameter(
143125
"__session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session
@@ -150,7 +132,7 @@ async def dependency(*, __session: async_scoped_session, **params: Any):
150132
]
151133
)
152134

153-
return Depends(dependency)
135+
return Depends(__dependency)
154136

155137

156138
def generic_issubclass(scls: Any, cls: Any) -> Any:
@@ -239,16 +221,16 @@ def is_editable(plugin: Plugin) -> bool:
239221
with suppress(PackageNotFoundError):
240222
dist = distribution(plugin.name.replace("_", "-"))
241223

242-
if not (dist or plugin.module.__file__ is None):
224+
if not dist and plugin.module.__file__:
243225
path = Path(plugin.module.__file__)
244226
for name in pkgs.get(plugin.module_name.split(".")[0], ()):
245227
dist = distribution(name)
246-
if path in map(methodcaller("locate"), dist.files or ()):
228+
if path in (file.locate() for file in dist.files or ()):
247229
break
248230
else:
249231
dist = None
250232

251-
if dist is None:
233+
if not dist:
252234
return True
253235

254236
# https://github.com/pdm-project/pdm/blob/fee1e6bffd7de30315e2134e19f9a6f58e15867c/src/pdm/utils.py#L361-L374
@@ -263,12 +245,10 @@ def is_editable(plugin: Plugin) -> bool:
263245
return direct_url_data.get("dir_info", {}).get("editable", False)
264246

265247

266-
def get_subclasses(cls: type[_T]) -> set[type[_T]]:
267-
subclasses = set()
248+
def get_subclasses(cls: type[_T]) -> Generator[type[_T], None, None]:
249+
yield from cls.__subclasses__()
268250
for subclass in cls.__subclasses__():
269-
subclasses.add(subclass)
270-
subclasses.update(get_subclasses(subclass))
271-
return subclasses
251+
yield from get_subclasses(subclass)
272252

273253

274254
if sys.version_info >= (3, 10):

0 commit comments

Comments
 (0)