Skip to content

Commit 321694d

Browse files
committed
more extensions doc work, some typing work
1 parent 0bd5a37 commit 321694d

File tree

14 files changed

+418
-67
lines changed

14 files changed

+418
-67
lines changed

django_typer/__init__.py

Lines changed: 110 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152

153153
P = ParamSpec("P")
154154
R = t.TypeVar("R")
155+
C = t.TypeVar("C", bound=BaseCommand)
155156

156157
_CACHE_KEY = "_register_typer"
157158

@@ -220,15 +221,52 @@ def handle(
220221
}
221222

222223

224+
@t.overload
225+
def get_command( # type: ignore[overload-overlap]
226+
command_name: str,
227+
stdout: t.Optional[t.IO[str]] = None,
228+
stderr: t.Optional[t.IO[str]] = None,
229+
no_color: bool = False,
230+
force_color: bool = False,
231+
**kwargs,
232+
) -> BaseCommand: ...
233+
234+
235+
@t.overload
236+
# mypy seems to break on this one, but this is correct
223237
def get_command(
224238
command_name: str,
225-
*subcommand: str,
239+
cmd_type: t.Type[C],
226240
stdout: t.Optional[t.IO[str]] = None,
227241
stderr: t.Optional[t.IO[str]] = None,
228242
no_color: bool = False,
229243
force_color: bool = False,
230244
**kwargs,
231-
) -> t.Union[BaseCommand, MethodType]:
245+
) -> C: ...
246+
247+
248+
@t.overload
249+
def get_command(
250+
command_name: str,
251+
path0: str,
252+
*path: str,
253+
stdout: t.Optional[t.IO[str]] = None,
254+
stderr: t.Optional[t.IO[str]] = None,
255+
no_color: bool = False,
256+
force_color: bool = False,
257+
**kwargs,
258+
) -> MethodType: ...
259+
260+
261+
def get_command(
262+
command_name,
263+
*path,
264+
stdout=None,
265+
stderr=None,
266+
no_color: bool = False,
267+
force_color: bool = False,
268+
**kwargs,
269+
):
232270
"""
233271
Get a Django_ command by its name and instantiate it with the provided options. This
234272
will work for subclasses of BaseCommand_ as well as for :class:`~django_typer.TyperCommand`
@@ -261,8 +299,17 @@ def get_command(
261299
divide = get_command('hierarchy', 'math', 'divide')
262300
result = divide(10, 2)
263301
302+
When fetching an entire TyperCommand (i.e. no group or subcommand path), you may supply
303+
the type of the expected TyperCommand as the second argument. This will allow the type
304+
system to infer the correct return type:
305+
306+
.. code-block:: python
307+
308+
from myapp.management.commands import Command as Hierarchy
309+
hierarchy: Hierarchy = get_command('hierarchy', Hierarchy)
310+
264311
:param command_name: the name of the command to get
265-
:param subcommand: the subcommand to get if any
312+
:param path: the path walking down the group/command tree
266313
:param stdout: the stdout stream to use
267314
:param stderr: the stderr stream to use
268315
:param no_color: whether to disable color
@@ -274,16 +321,15 @@ def get_command(
274321
module = import_module(
275322
f"{get_commands()[command_name]}.management.commands.{command_name}"
276323
)
277-
cmd = module.Command(
324+
cmd: BaseCommand = module.Command(
278325
stdout=stdout,
279326
stderr=stderr,
280327
no_color=no_color,
281328
force_color=force_color,
282329
**kwargs,
283330
)
284-
if subcommand:
285-
method = cmd.get_subcommand(*subcommand).click_command._callback.__wrapped__
286-
return MethodType(method, cmd) # return the bound method
331+
if path and (isinstance(path[0], str) or len(path) > 1):
332+
return t.cast(TyperCommand, cmd).get_subcommand(*path).callback
287333

288334
return cmd
289335

@@ -406,7 +452,7 @@ def __init__(
406452
parent.children.append(self)
407453

408454

409-
class _DjangoAdapterMixin(with_typehint(CoreTyperGroup)): # type: ignore[misc]
455+
class DjangoTyperCommand(with_typehint(CoreTyperGroup)): # type: ignore[misc]
410456
"""
411457
A mixin we use to add additional needed contextual awareness to click Commands
412458
and Groups.
@@ -556,15 +602,15 @@ def call_with_self(*args, **kwargs):
556602
)
557603

558604

559-
class TyperCommandWrapper(_DjangoAdapterMixin, CoreTyperCommand):
605+
class TyperCommandWrapper(DjangoTyperCommand, CoreTyperCommand):
560606
"""
561607
This class extends the TyperCommand class to work with the django-typer
562608
interfaces. If you need to add functionality to the command class - which
563609
you should not - you should inherit from this class.
564610
"""
565611

566612

567-
class TyperGroupWrapper(_DjangoAdapterMixin, CoreTyperGroup):
613+
class TyperGroupWrapper(DjangoTyperCommand, CoreTyperGroup):
568614
"""
569615
This class extends the TyperGroup class to work with the django-typer
570616
interfaces. If you need to add functionality to the group class - which
@@ -2162,16 +2208,47 @@ class CommandNode:
21622208
"""
21632209

21642210
name: str
2165-
click_command: click.Command
2211+
"""
2212+
The name of the group or command that this node represents.
2213+
"""
2214+
2215+
click_command: DjangoTyperCommand
2216+
"""
2217+
The click command object that this node represents.
2218+
"""
2219+
21662220
context: TyperContext
2221+
"""
2222+
The Typer context object used to run this command.
2223+
"""
2224+
21672225
django_command: "TyperCommand"
2226+
"""
2227+
Back reference to the django command instance that this command belongs to.
2228+
"""
2229+
21682230
parent: t.Optional["CommandNode"] = None
2231+
"""
2232+
The parent node of this command node or None if this is a root node.
2233+
"""
2234+
21692235
children: t.Dict[str, "CommandNode"]
2236+
"""
2237+
The child group and command nodes of this command node.
2238+
"""
2239+
2240+
@property
2241+
def callback(self) -> t.Callable[..., t.Any]:
2242+
"""Get the function for this command or group"""
2243+
cb = getattr(self.click_command._callback, "__wrapped__")
2244+
return (
2245+
MethodType(cb, self.django_command) if self.click_command.is_method else cb
2246+
)
21702247

21712248
def __init__(
21722249
self,
21732250
name: str,
2174-
click_command: click.Command,
2251+
click_command: DjangoTyperCommand,
21752252
context: TyperContext,
21762253
django_command: "TyperCommand",
21772254
parent: t.Optional["CommandNode"] = None,
@@ -2197,8 +2274,9 @@ def get_command(self, *command_path: str) -> "CommandNode":
21972274
Return the command node for the given command path at or below
21982275
this node.
21992276
2200-
:param command_path: the path(s) to the command to retrieve
2201-
:return: the command node at the given path
2277+
:param command_path: the parent group names followed by the name of the command
2278+
or group to retrieve
2279+
:return: the command node at the given group/subcommand path
22022280
:raises LookupError: if the command path does not exist
22032281
"""
22042282
if not command_path:
@@ -2208,6 +2286,15 @@ def get_command(self, *command_path: str) -> "CommandNode":
22082286
except KeyError as err:
22092287
raise LookupError(f'No such command "{command_path[0]}"') from err
22102288

2289+
def __call__(self, *args, **kwargs) -> t.Any:
2290+
"""
2291+
Call this command or group directly.
2292+
2293+
:param args: the arguments to pass to the command or group callback
2294+
:param kwargs: the named parameters to pass to the command or group callback
2295+
"""
2296+
return self.callback(*args, **kwargs)
2297+
22112298

22122299
class TyperParser:
22132300
"""
@@ -2902,7 +2989,14 @@ def __init__(
29022989
) from rerr
29032990

29042991
def get_subcommand(self, *command_path: str) -> CommandNode:
2905-
"""Get the CommandNode"""
2992+
"""
2993+
Retrieve a :class:`~django_typer.CommandNode` at the given command path.
2994+
2995+
:param command_path: the path to the command to retrieve, where each argument
2996+
is the string name in order of a group or command in the hierarchy.
2997+
:return: the command node at the given path
2998+
:raises LookupError: if no group or command exists at the given path
2999+
"""
29063000
return self.command_tree.get_command(*command_path)
29073001

29083002
def _filter_commands(
@@ -2955,6 +3049,7 @@ def _build_cmd_tree(
29553049
:param node: the parent node or None if this is a root node
29563050
"""
29573051
assert cmd.name
3052+
assert isinstance(cmd, DjangoTyperCommand)
29583053
ctx = Context(cmd, info_name=info_name, parent=parent, django_command=self)
29593054
current = CommandNode(cmd.name, cmd, ctx, self, parent=node)
29603055
if node:

django_typer/examples/tutorial/backup/backup.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
from django.conf import settings
88
from django.core.management import CommandError, call_command
99

10-
from django_typer import TyperCommand, command, completers, initialize
10+
from django_typer import (
11+
TyperCommand,
12+
CommandNode,
13+
command,
14+
completers,
15+
initialize,
16+
)
1117

1218

1319
class Command(TyperCommand):
@@ -29,7 +35,7 @@ class Command(TyperCommand):
2935
@initialize(invoke_without_command=True)
3036
def init_or_run_all(
3137
self,
32-
# if we add a context argument click will provide it
38+
# if we add a context argument Typer will provide it
3339
context: typer.Context,
3440
output_directory: t.Annotated[
3541
Path,
@@ -53,9 +59,9 @@ def init_or_run_all(
5359
# if it was not we run all the backup routines
5460
if not context.invoked_subcommand:
5561
for cmd in self.get_backup_routines():
56-
getattr(self, cmd)()
62+
cmd()
5763

58-
def get_backup_routines(self) -> t.List[str]:
64+
def get_backup_routines(self) -> t.List[CommandNode]:
5965
"""
6066
Return the list of backup subcommands. This is every registered command
6167
except for the list command.
@@ -64,8 +70,8 @@ def get_backup_routines(self) -> t.List[str]:
6470
# except for list, which we know to not be a backup routine
6571
return [
6672
cmd
67-
for cmd in self.get_subcommand().children.keys()
68-
if cmd != "list"
73+
for name, cmd in self.get_subcommand().children.items()
74+
if name != "list"
6975
]
7076

7177
@command()
@@ -75,14 +81,15 @@ def list(self):
7581
"""
7682
self.echo("Default backup routines:")
7783
for cmd in self.get_backup_routines():
78-
kwargs = {
79-
name: str(param.default)
84+
sig = {
85+
name: param.default
8086
for name, param in inspect.signature(
81-
getattr(self, cmd)
87+
cmd.callback
8288
).parameters.items()
89+
if not name == "self"
8390
}
84-
params = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
85-
self.secho(f" {cmd}({params})", fg="green")
91+
params = ", ".join([f"{k}={v}" for k, v in sig.items()])
92+
self.secho(f" {cmd.name}({params})", fg="green")
8693

8794
@command()
8895
def database(

django_typer/examples/tutorial/backup/backup_ext1.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
Command as Backup,
1010
)
1111

12-
DEFAULT_MEDIA_FILENAME = "media.tar.gz"
13-
1412

13+
# instead of inheriting we add the command using the classmethod decorator
14+
# on the backup Command class to decorate a module scoped function
1515
@Backup.command()
1616
def media(
17+
# self is optional, but if you want to access the command instance, you
18+
# can specify it
1719
self,
1820
filename: t.Annotated[
1921
str,
@@ -22,7 +24,7 @@ def media(
2224
"--filename",
2325
help=("The name of the file to use for the media backup tar."),
2426
),
25-
] = DEFAULT_MEDIA_FILENAME,
27+
] = "media.tar.gz",
2628
):
2729
"""
2830
Backup the media files (i.e. those files in MEDIA_ROOT).

django_typer/examples/tutorial/backup/backup_ext2.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import shutil
12
import subprocess
23
import typing as t
4+
import datetime
35

46
import typer
57

8+
from django.conf import settings
69
from django_typer.tests.apps.backup.backup.management.commands.backup import (
710
Command as Backup,
811
)
@@ -29,3 +32,17 @@ def environment(
2932
typer.echo(f"Capturing python environment to {output_file}")
3033
with output_file.open("w") as f:
3134
subprocess.run(["pip", "freeze"], stdout=f)
35+
36+
37+
@Backup.command()
38+
def database(self):
39+
"""
40+
Backup the database by copying the sqlite file and tagging it with the
41+
current date.
42+
"""
43+
db_file = self.output_directory / f"backup_{datetime.date.today()}.sqlite3"
44+
self.echo("Backing up database to {db_file}")
45+
shutil.copy(
46+
settings.DATABASES["default"]["NAME"],
47+
db_file,
48+
)

django_typer/examples/tutorial/backup/backup_inherit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
)
1212

1313

14-
class Command(Backup):
14+
class Command(Backup): # inherit from the original command
15+
# add a new command called media that archives the MEDIA_ROOT dir
1516
@command()
1617
def media(
1718
self,

django_typer/tests/apps/backup/extend2/apps.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,4 @@ class Extend2Config(AppConfig):
88

99
def ready(self):
1010
from .management import extensions
11-
1211
register_command_extensions(extensions)

django_typer/tests/test_backup_example.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from django.test import SimpleTestCase
66
from pathlib import Path
77
from django_typer.tests.utils import run_command
8+
import datetime
89
import shutil
910

1011

@@ -93,9 +94,7 @@ def test_extend_backup(self):
9394
self.assertEqual(retcode, 0, msg=stderr)
9495
lines = [line.strip() for line in stdout.strip().splitlines()[1:]]
9596
self.assertEqual(len(lines), 3)
96-
self.assertTrue(
97-
"database(filename={database}.json, databases=['default'])" in lines
98-
)
97+
self.assertTrue("database()" in lines)
9998
self.assertTrue("environment(filename=requirements.txt)" in lines)
10099
self.assertTrue("media(filename=media.tar.gz)" in lines)
101100

@@ -108,7 +107,9 @@ def test_extend_backup(self):
108107
)
109108
self.assertEqual(retcode, 0, msg=stderr)
110109
self.assertTrue(BACKUP_DIRECTORY.exists())
111-
self.assertTrue((BACKUP_DIRECTORY / "default.json").exists())
110+
self.assertTrue(
111+
(BACKUP_DIRECTORY / f"backup_{datetime.date.today()}.sqlite3").exists()
112+
)
112113
self.assertTrue((BACKUP_DIRECTORY / "media.tar.gz").exists())
113114
self.assertTrue((BACKUP_DIRECTORY / "requirements.txt").exists())
114115
self.assertTrue(len(os.listdir(BACKUP_DIRECTORY)) == 3)

0 commit comments

Comments
 (0)