Skip to content

Commit c32aa66

Browse files
committed
make commands callable
1 parent d552c61 commit c32aa66

File tree

4 files changed

+33
-2
lines changed

4 files changed

+33
-2
lines changed

django_typer/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,10 @@ def handle(
257257
}
258258

259259

260+
class CallableCommand(t.Protocol):
261+
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
262+
263+
260264
@t.overload # pragma: no cover
261265
def get_command( # type: ignore[overload-overlap]
262266
command_name: str,
@@ -265,7 +269,7 @@ def get_command( # type: ignore[overload-overlap]
265269
no_color: bool = False,
266270
force_color: bool = False,
267271
**kwargs,
268-
) -> BaseCommand: ...
272+
) -> CallableCommand: ...
269273

270274

271275
@t.overload # pragma: no cover
@@ -344,6 +348,13 @@ def get_command(
344348
from myapp.management.commands import Command as Hierarchy
345349
hierarchy: Hierarchy = get_command('hierarchy', Hierarchy)
346350
351+
.. note::
352+
353+
If get_command fetches a BaseCommand that does not implement __call__ get_command will
354+
make the command callable by adding a __call__ method that calls the handle method of
355+
the BaseCommand. This allows you to call the command like get_command("command")() with
356+
confidence.
357+
347358
:param command_name: the name of the command to get
348359
:param path: the path walking down the group/command tree
349360
:param stdout: the stdout stream to use
@@ -367,6 +378,13 @@ def get_command(
367378
if path and (isinstance(path[0], str) or len(path) > 1):
368379
return t.cast(TyperCommand, cmd).get_subcommand(*path).callback
369380

381+
if not hasattr(cmd, "__call__"):
382+
setattr(
383+
cmd.__class__,
384+
"__call__",
385+
lambda self, *args, **options: self.handle(*args, **options),
386+
)
387+
370388
return cmd
371389

372390

django_typer/management/commands/shellcompletion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def get_completion() -> None:
589589
cmd_str=cmd_str
590590
)
591591
) from err
592-
return # otherwise nowhere to go - just empty out
592+
raise # otherwise nowhere to go
593593

594594
if isinstance(cmd, TyperCommand): # type: ignore[unreachable]
595595
# this will exit out so no return is needed here
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from django.core.management import BaseCommand
2+
3+
4+
class Command(BaseCommand):
5+
def handle(self, *args, **options):
6+
return f"base({args}, {options})"

django_typer/tests/test_basics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,10 @@ def test_renaming(self):
118118
self.assertEqual(get_command("rename", TyperCommand)(), "handle")
119119
self.assertEqual(get_command("rename", Rename).subcommand1(), "subcommand1")
120120
self.assertEqual(get_command("rename", Rename).subcommand2(), "subcommand2")
121+
122+
def test_get_command_make_callable(self):
123+
args = (1, 2, 3)
124+
kwargs = {"named": "test!"}
125+
self.assertEqual(
126+
get_command("base")(*args, **kwargs), f"base({args}, {kwargs})"
127+
)

0 commit comments

Comments
 (0)