Skip to content

Commit 5e676a9

Browse files
committed
rework CommandNode tree to be lazily loaded
1 parent fe4613b commit 5e676a9

File tree

3 files changed

+80
-131
lines changed

3 files changed

+80
-131
lines changed

django_typer/__init__.py

Lines changed: 72 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
import threading
8484
import typing as t
8585
from copy import copy, deepcopy
86+
from functools import cached_property
8687
from importlib import import_module
8788
from pathlib import Path
8889
from types import MethodType, SimpleNamespace
@@ -270,7 +271,7 @@ def get_command( # type: ignore[overload-overlap]
270271
no_color: bool = False,
271272
force_color: bool = False,
272273
**kwargs,
273-
) -> CallableCommand: ...
274+
) -> BaseCommand: ...
274275

275276

276277
@t.overload # pragma: no cover
@@ -794,7 +795,12 @@ def register(
794795

795796
def _get_direct_function(
796797
obj: "TyperCommand",
797-
app_node: t.Union["Typer", typer.models.CommandInfo, typer.models.TyperInfo],
798+
app_node: t.Union[
799+
"Typer",
800+
typer.models.CommandInfo,
801+
typer.models.TyperInfo,
802+
t.Callable[..., t.Any],
803+
],
798804
):
799805
"""
800806
Get a direct callable function bound to the given object if it is not static held by the given
@@ -803,8 +809,11 @@ def _get_direct_function(
803809
if isinstance(app_node, Typer):
804810
method = app_node.is_method
805811
cb = getattr(app_node.registered_callback, "callback", app_node.info.callback)
812+
elif cb := getattr(app_node, "callback", None):
813+
method = is_method(cb)
806814
else:
807-
cb = app_node.callback
815+
assert callable(app_node)
816+
cb = app_node
808817
method = is_method(cb)
809818
assert cb
810819
return MethodType(cb, obj) if method else staticmethod(cb)
@@ -886,36 +895,10 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
886895
return _get_direct_function(cmd, self)(*args, **kwargs)
887896
return super().__call__(*args, **kwargs)
888897

889-
@t.overload # pragma: no cover
890-
def __get__(
891-
self, obj: "TyperCommandMeta", owner: t.Type["TyperCommandMeta"]
892-
) -> "Typer[P, R]": ...
893-
894-
@t.overload # pragma: no cover
895-
def __get__(
896-
self,
897-
obj: None,
898-
owner: t.Type["TyperCommand"],
899-
) -> "Typer[P, R]": ...
900-
901-
@t.overload # pragma: no cover
902-
def __get__(
903-
self,
904-
obj: "Typer",
905-
owner: t.Type["Typer"],
906-
) -> "Typer[P, R]": ...
907-
908-
@t.overload # pragma: no cover
909-
def __get__(
910-
self, obj: "TyperCommand", owner: t.Any = None
911-
) -> "Typer[P, R]": # t.Union[MethodType, t.Callable[P, R]]
912-
# todo - we could return the generic callable type here but the problem
913-
# is self is included in the ParamSpec and it seems tricky to remove?
914-
# MethodType loses the parameters but is preferable to type checking errors
915-
# https://github.com/bckohan/django-typer/issues/73
916-
...
917-
918-
def __get__(self, obj, owner=None): # pyright: ignore[reportInconsistentOverload]
898+
# todo - this results in type hinting expecting self to be passed explicitly
899+
# when this is called as a callable
900+
# https://github.com/bckohan/django-typer/issues/73
901+
def __get__(self, obj, _=None) -> "Typer[P, R]":
919902
"""
920903
Our Typer app wrapper also doubles as a descriptor, so when
921904
it is accessed on the instance, we return the wrapped function
@@ -2113,7 +2096,7 @@ class CommandNode:
21132096
The click command object that this node represents.
21142097
"""
21152098

2116-
context: TyperContext
2099+
context: Context
21172100
"""
21182101
The Typer context object used to run this command.
21192102
"""
@@ -2123,38 +2106,50 @@ class CommandNode:
21232106
Back reference to the django command instance that this command belongs to.
21242107
"""
21252108

2126-
parent: t.Optional["CommandNode"] = None
2127-
"""
2128-
The parent node of this command node or None if this is a root node.
2129-
"""
2130-
2131-
children: t.Dict[str, "CommandNode"]
2132-
"""
2133-
The child group and command nodes of this command node.
2134-
"""
2109+
@cached_property
2110+
def children(self) -> t.Dict[str, "CommandNode"]:
2111+
"""
2112+
The child group and command nodes of this command node.
2113+
"""
2114+
return {
2115+
name: CommandNode(name, cmd, self.django_command, parent=self.context)
2116+
for name, cmd in getattr(
2117+
self.context.command,
2118+
"commands",
2119+
{
2120+
name: self.context.command.get_command(self.context, name) # type: ignore[attr-defined]
2121+
for name in (
2122+
self.context.command.list_commands(self.context)
2123+
if isinstance(self.context.command, click.MultiCommand)
2124+
else []
2125+
)
2126+
},
2127+
).items()
2128+
}
21352129

21362130
@property
21372131
def callback(self) -> t.Callable[..., t.Any]:
21382132
"""Get the function for this command or group"""
2139-
cb = getattr(self.click_command._callback, "__wrapped__")
2140-
return (
2141-
MethodType(cb, self.django_command) if self.click_command.is_method else cb
2133+
return _get_direct_function(
2134+
self.django_command, getattr(self.click_command._callback, "__wrapped__")
21422135
)
21432136

21442137
def __init__(
21452138
self,
21462139
name: str,
21472140
click_command: DjangoTyperMixin,
2148-
context: TyperContext,
21492141
django_command: "TyperCommand",
2150-
parent: t.Optional["CommandNode"] = None,
2142+
parent: t.Optional[Context] = None,
21512143
):
21522144
self.name = name
21532145
self.click_command = click_command
2154-
self.context = context
21552146
self.django_command = django_command
2156-
self.parent = parent
2157-
self.children = {}
2147+
self.context = Context(
2148+
self.click_command,
2149+
info_name=name,
2150+
django_command=django_command,
2151+
parent=parent,
2152+
)
21582153

21592154
def print_help(self) -> t.Optional[str]:
21602155
"""
@@ -2235,9 +2230,12 @@ def nargs(self) -> int:
22352230
@property
22362231
def option_strings(self) -> t.List[str]:
22372232
"""
2238-
The list of allowable command line option strings for this parameter.
2233+
call_command uses this to determine a mapping of supplied options to function
2234+
arguments. I.e. it will remap option_string: dest. We don't want this because
2235+
we'd rather have supplied parameters line up with their function arguments to
2236+
allow deconfliction when CLI options share the same name.
22392237
"""
2240-
return list(self.param.opts) if isinstance(self.param, click.Option) else []
2238+
return []
22412239

22422240
_actions: t.List[t.Any]
22432241
_mutually_exclusive_groups: t.List[t.Any] = []
@@ -2251,23 +2249,22 @@ def __init__(self, django_command: "TyperCommand", prog_name, subcommand):
22512249
self.django_command = django_command
22522250
self.prog_name = prog_name
22532251
self.subcommand = subcommand
2252+
self.tree = self.django_command.command_tree
2253+
self.tree.context.info_name = f"{self.prog_name} {self.subcommand}"
22542254

22552255
def populate_params(node: CommandNode) -> None:
22562256
for param in node.click_command.params:
22572257
self._actions.append(self.Action(param))
22582258
for child in node.children.values():
22592259
populate_params(child)
22602260

2261-
populate_params(self.django_command.command_tree)
2261+
populate_params(self.tree)
22622262

22632263
def print_help(self, *command_path: str):
22642264
"""
22652265
Print the help for the given command path to stdout of the django command.
22662266
"""
2267-
self.django_command.command_tree.context.info_name = (
2268-
f"{self.prog_name} {self.subcommand}"
2269-
)
2270-
command_node = self.django_command.get_subcommand(*command_path)
2267+
command_node = self.tree.get_command(*command_path)
22712268
hlp = command_node.print_help()
22722269
if hlp:
22732270
self.django_command.stdout.write(
@@ -2470,12 +2467,22 @@ def command2(self, option: t.Optional[str] = None):
24702467

24712468
help: t.Optional[t.Union[DefaultPlaceholder, str]] = Default(None) # type: ignore
24722469

2473-
command_tree: CommandNode
2474-
24752470
# allow deriving commands to override handle() from BaseCommand
24762471
# without triggering static type checking complaints
24772472
handle = None # type: ignore
24782473

2474+
@property
2475+
def command_tree(self) -> CommandNode:
2476+
"""
2477+
Get the root CommandNode for this command. Allows easy traversal of the command
2478+
tree.
2479+
"""
2480+
return CommandNode(
2481+
f"{sys.argv[0]} {self._name}",
2482+
click_command=t.cast(DjangoTyperMixin, get_typer_command(self.typer_app)),
2483+
django_command=self,
2484+
)
2485+
24792486
@classmethod
24802487
def initialize(
24812488
cmd, # pyright: ignore[reportSelfClsParameterName]
@@ -2879,9 +2886,7 @@ def __init__(
28792886
self.stdout.style_func = stdout_style_func
28802887
self.stderr.style_func = stderr_style_func
28812888
try:
2882-
self.command_tree = self._build_cmd_tree(
2883-
get_typer_command(self.typer_app)
2884-
)
2889+
assert get_typer_command(self.typer_app)
28852890
except RuntimeError as rerr:
28862891
raise NotImplementedError(
28872892
_(
@@ -2900,67 +2905,6 @@ def get_subcommand(self, *command_path: str) -> CommandNode:
29002905
"""
29012906
return self.command_tree.get_command(*command_path)
29022907

2903-
def _filter_commands(
2904-
self, ctx: TyperContext, cmd_filter: t.Optional[t.List[str]] = None
2905-
):
2906-
"""
2907-
Fetch subcommand names. Given a click context, return the list of commands
2908-
that are valid return the list of commands that are valid for the given
2909-
context.
2910-
2911-
:param ctx: the click context
2912-
:param cmd_filter: a list of command names to filter by, if None no subcommands
2913-
will be filtered out
2914-
:return: the list of command names that are valid for the given context
2915-
"""
2916-
return sorted(
2917-
[
2918-
cmd
2919-
for name, cmd in getattr(
2920-
ctx.command,
2921-
"commands",
2922-
{
2923-
name: ctx.command.get_command(ctx, name) # type: ignore[attr-defined]
2924-
for name in (
2925-
ctx.command.list_commands(ctx)
2926-
if isinstance(ctx.command, click.MultiCommand)
2927-
else []
2928-
)
2929-
},
2930-
).items()
2931-
if not cmd_filter or name in cmd_filter
2932-
],
2933-
key=lambda item: item.name,
2934-
)
2935-
2936-
def _build_cmd_tree(
2937-
self,
2938-
cmd: click.Command,
2939-
parent: t.Optional[Context] = None,
2940-
info_name: t.Optional[str] = None,
2941-
node: t.Optional[CommandNode] = None,
2942-
):
2943-
"""
2944-
Recursively build the CommandNode tree used to walk the click command
2945-
hierarchy.
2946-
2947-
:param cmd: the click command to build the tree from
2948-
:param parent: the parent click context
2949-
:param info_name: the name of the command
2950-
:param node: the parent node or None if this is a root node
2951-
"""
2952-
assert cmd.name
2953-
assert isinstance(cmd, DjangoTyperMixin)
2954-
ctx = Context(cmd, info_name=info_name, parent=parent, django_command=self)
2955-
current = CommandNode(cmd.name, cmd, ctx, self, parent=node)
2956-
if node:
2957-
node.children[cmd.name] = current
2958-
for sub_cmd in self._filter_commands(ctx):
2959-
self._build_cmd_tree(
2960-
sub_cmd, parent=ctx, info_name=sub_cmd.name, node=current
2961-
)
2962-
return current
2963-
29642908
def __init_subclass__(cls, **_):
29652909
"""Avoid passing typer arguments up the subclass init chain"""
29662910
return super().__init_subclass__()
@@ -3006,6 +2950,8 @@ def __getattr__(self, name: str) -> t.Any:
30062950
and return that command or group if the attribute name matches the command/group
30072951
function OR its registered CLI name.
30082952
"""
2953+
if isinstance(attr := getattr(self.__class__, name, None), property):
2954+
return t.cast(t.Callable, attr.fget)(self)
30092955
init = getattr(
30102956
self.typer_app.registered_callback,
30112957
"callback",
@@ -3016,7 +2962,8 @@ def __getattr__(self, name: str) -> t.Any:
30162962
found = depth_first_match(self.typer_app, name)
30172963
if found:
30182964
if isinstance(found, Typer):
3019-
# todo shouldn't be needed
2965+
# todo shouldn't be needed - wrap these in a proxy,
2966+
# avoid need for threading local
30202967
found._local.object = self
30212968
else:
30222969
return _get_direct_function(self, found)

django_typer/tests/test_callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from django.core.management import call_command
66
from django.test import TestCase
77

8-
from django_typer import get_command
8+
from django_typer import get_command, TyperCommand
99
from django_typer.tests.utils import run_command
1010

1111

django_typer/tests/test_overloads.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,19 @@ def test_overloaded_call_command(self):
115115
)
116116
),
117117
{
118-
"samename": {"precision": 1, "flag": True},
119118
"test": {"precision": 5, "flag": False},
119+
"samename": {"precision": 1, "flag": True},
120120
},
121121
)
122+
123+
ret = json.loads(
124+
call_command("overloaded", ["test", "5", "samename", "1"], flag=True)
125+
)
122126
self.assertEqual(
123-
json.loads(
124-
call_command("overloaded", ["test", "5", "samename", "1"], flag=True)
125-
),
127+
ret,
126128
{
127-
"samename": {"precision": 1, "flag": True},
128129
"test": {"precision": 5, "flag": True},
130+
"samename": {"precision": 1, "flag": True},
129131
},
130132
)
131133

0 commit comments

Comments
 (0)