Skip to content

Commit e6d0f97

Browse files
committed
Add BoundProxy helper class to bulletproof direct function calls and remove threading dependency
1 parent 5e676a9 commit e6d0f97

File tree

2 files changed

+53
-56
lines changed

2 files changed

+53
-56
lines changed

django_typer/__init__.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@
8080

8181
import inspect
8282
import sys
83-
import threading
8483
import typing as t
8584
from copy import copy, deepcopy
8685
from functools import cached_property
@@ -793,14 +792,17 @@ def register(
793792
setattr(callback, _CACHE_KEY, register)
794793

795794

795+
TyperFunction = t.Union[
796+
"Typer[P, R]",
797+
typer.models.CommandInfo,
798+
typer.models.TyperInfo,
799+
t.Callable[..., t.Any],
800+
]
801+
802+
796803
def _get_direct_function(
797804
obj: "TyperCommand",
798-
app_node: t.Union[
799-
"Typer",
800-
typer.models.CommandInfo,
801-
typer.models.TyperInfo,
802-
t.Callable[..., t.Any],
803-
],
805+
app_node: TyperFunction,
804806
):
805807
"""
806808
Get a direct callable function bound to the given object if it is not static held by the given
@@ -872,8 +874,6 @@ class Typer(typer.Typer, t.Generic[P, R], metaclass=AppFactory):
872874
registered_commands: t.List[typer.models.CommandInfo] = []
873875
registered_callback: t.Optional[typer.models.TyperInfo] = None
874876

875-
_local = threading.local()
876-
877877
is_method: t.Optional[bool] = None
878878
top_level: bool = False
879879

@@ -885,16 +885,6 @@ def django_command(self) -> t.Optional[t.Type["TyperCommand"]]:
885885
def django_command(self, cmd: t.Optional[t.Type["TyperCommand"]]):
886886
self._django_command = cmd
887887

888-
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
889-
"""
890-
Typers more than one level deep will route invocations through this
891-
function which wraps our initializer.
892-
"""
893-
cmd = self.cmd_obj()
894-
if self.parent and cmd: # don't call direct if root app
895-
return _get_direct_function(cmd, self)(*args, **kwargs)
896-
return super().__call__(*args, **kwargs)
897-
898888
# todo - this results in type hinting expecting self to be passed explicitly
899889
# when this is called as a callable
900890
# https://github.com/bckohan/django-typer/issues/73
@@ -907,18 +897,16 @@ def __get__(self, obj, _=None) -> "Typer[P, R]":
907897
on the class and subclasses.
908898
"""
909899
if isinstance(obj, TyperCommand):
910-
self._local.object = obj
911-
else:
912-
self._local.object = None
900+
return t.cast(Typer[P, R], BoundProxy(obj, self))
913901
return self
914902

915903
def __getattr__(self, name: str) -> t.Any:
916-
cmd_obj = self.cmd_obj()
904+
if isinstance(attr := getattr(self.__class__, name, None), property):
905+
return t.cast(t.Callable, attr.fget)(self)
906+
917907
for cmd in self.registered_commands:
918908
assert cmd.callback
919909
if name in (cmd.callback.__name__, cmd.name):
920-
if cmd_obj:
921-
return _get_direct_function(cmd_obj, cmd)
922910
return cmd
923911
for grp in self.registered_groups:
924912
cmd_grp = t.cast(Typer, grp.typer_instance)
@@ -935,25 +923,6 @@ def __getattr__(self, name: str) -> t.Any:
935923
)
936924
)
937925

938-
def cmd_obj(self) -> t.Optional["TyperCommand"]:
939-
"""
940-
If this command group was ultimately accessed from a TyperCommand instance,
941-
get that instance. For instance:
942-
943-
.. code-block:: python
944-
945-
cmd = Command()
946-
assert cmd.lvl1.lvl2.cmd_obj() is cmd
947-
948-
This enables namespaced direct calls that work despite group or command
949-
name collisions.
950-
"""
951-
assert self.parent is None or isinstance(self.parent, Typer)
952-
obj = self._local.object or (
953-
self.parent.cmd_obj() if isinstance(self.parent, Typer) else None
954-
)
955-
return obj if isinstance(obj, TyperCommand) else None
956-
957926
def __init__(
958927
self,
959928
*args,
@@ -988,7 +957,6 @@ def __init__(
988957
assert not args # should have been removed by metaclass
989958
self.parent = parent
990959
self._django_command = django_command
991-
self._local.object = None
992960
self.top_level = kwargs.pop("top_level", False)
993961
typer_app = kwargs.pop("typer_app", None)
994962
callback = _strip_static(callback)
@@ -1334,6 +1302,44 @@ def create_app(func: t.Callable[P2, R2]) -> Typer[P2, R2]:
13341302
return create_app
13351303

13361304

1305+
class BoundProxy(t.Generic[P, R]):
1306+
"""
1307+
A helper class that proxies the Typer or command objects and binds them
1308+
to the django command instance.
1309+
"""
1310+
1311+
command: "TyperCommand"
1312+
proxied: TyperFunction
1313+
1314+
def __init__(self, command: "TyperCommand", proxied: TyperFunction):
1315+
self.command = command
1316+
self.proxied = proxied
1317+
1318+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
1319+
if isinstance(self.proxied, Typer) and not self.proxied.parent:
1320+
# if we're calling a top level Typer app we need invoke Typer's call
1321+
return self.proxied(*args, **kwargs)
1322+
return _get_direct_function(self.command, self.proxied)(*args, **kwargs)
1323+
1324+
def __getattr__(self, name: str) -> t.Any:
1325+
"""
1326+
If our proxied object __getattr__ returns a Typer or Command object we
1327+
wrap it in a BoundProxy so that it can be called directly as a method
1328+
on the django command instance.
1329+
"""
1330+
if hasattr(self.proxied, name):
1331+
attr = getattr(self.proxied, name)
1332+
if isinstance(attr, (Typer, typer.models.CommandInfo)):
1333+
return BoundProxy(self.command, attr)
1334+
return attr
1335+
1336+
raise AttributeError(
1337+
"{cls} object has no attribute {name}".format(
1338+
cls=self.__class__.__name__, name=name
1339+
)
1340+
)
1341+
1342+
13371343
def initialize(
13381344
name: t.Optional[str] = Default(None),
13391345
*,
@@ -2958,16 +2964,10 @@ def __getattr__(self, name: str) -> t.Any:
29582964
self.typer_app.info.callback,
29592965
)
29602966
if init and init and name == init.__name__:
2961-
return MethodType(init, self) if is_method(init) else staticmethod(init)
2967+
return BoundProxy(self, init)
29622968
found = depth_first_match(self.typer_app, name)
29632969
if found:
2964-
if isinstance(found, Typer):
2965-
# todo shouldn't be needed - wrap these in a proxy,
2966-
# avoid need for threading local
2967-
found._local.object = self
2968-
else:
2969-
return _get_direct_function(self, found)
2970-
return found
2970+
return BoundProxy(self, found)
29712971
raise AttributeError(
29722972
"{cls} object has no attribute {name}".format(
29732973
cls=self.__class__.__name__, name=name

django_typer/tests/test_overloads.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,6 @@ def test_grp_overload_direct(self):
230230
self.assertEqual(grp_overload.g1.l2.cmd(), "g1:l2:cmd()")
231231
print("----------")
232232

233-
self.assertEqual(grp_overload.g0.l2.cmd_obj(), grp_overload)
234-
self.assertEqual(grp_overload.g1.l2.cmd_obj(), grp_overload)
235-
236233
self.assertTrue(hasattr(grp_overload.g0.l2, "cmd2"))
237234
self.assertEqual(grp_overload.g0.l2.cmd2(), "g0:l2:cmd2()")
238235
self.assertEqual(grp_overload.g1.l2.cmd2(), "g1:l2:cmd2()")

0 commit comments

Comments
 (0)