Skip to content

Commit fe4613b

Browse files
committed
make stack introspection more bulletproof
1 parent c32aa66 commit fe4613b

File tree

5 files changed

+38
-12
lines changed

5 files changed

+38
-12
lines changed

django_typer/__init__.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@
124124
)
125125
from .utils import (
126126
_command_context,
127+
called_from_command_definition,
127128
called_from_module,
128129
get_usage_script,
129130
is_method,
@@ -258,7 +259,7 @@ def handle(
258259

259260

260261
class CallableCommand(t.Protocol):
261-
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
262+
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ... # pragma: no cover
262263

263264

264265
@t.overload # pragma: no cover
@@ -698,6 +699,10 @@ def grp(self):
698699
"""
699700

700701

702+
# staticmethod objects are not picklable which causes problems with deepcopy
703+
# hence the following mishegoss
704+
705+
701706
@t.overload # pragma: no cover
702707
def _check_static(
703708
func: typer.models.CommandFunctionType,
@@ -710,7 +715,8 @@ def _check_static(func: None) -> None: ...
710715

711716
def _check_static(func):
712717
"""
713-
Check if a function is a staticmethod and return it if it is.
718+
Check if a function is a staticmethod and return it if it is otherwise make
719+
it static if it should be but isn't.
714720
"""
715721
if func and not is_method(func) and not isinstance(func, staticmethod):
716722
return staticmethod(func)
@@ -805,6 +811,11 @@ def _get_direct_function(
805811

806812

807813
class AppFactory(type):
814+
"""
815+
A metaclass used to define/set Command classes into the defining module when
816+
the Typer-like functional interface is used.
817+
"""
818+
808819
def __call__(self, *args, **kwargs) -> "Typer":
809820
if called_from_module():
810821
frame = inspect.currentframe()
@@ -819,6 +830,7 @@ class Command(
819830
):
820831
pass
821832

833+
Command.__module__ = cmd_module.__name__ # spoof it hard
822834
setattr(cmd_module, "Command", Command)
823835
return Command.typer_app
824836
else:
@@ -2066,7 +2078,7 @@ def __getattr__(cls, name: str) -> t.Any:
20662078
Command.sub_grp or Command.sub_cmd
20672079
"""
20682080
if name != "typer_app":
2069-
if not called_from_module():
2081+
if called_from_command_definition():
20702082
if name in cls._defined_groups:
20712083
return cls._defined_groups[name]
20722084
elif cls.typer_app:
@@ -2528,7 +2540,7 @@ def init(self, ...):
25282540
:param rich_help_panel: the rich help panel to use - if rich is installed
25292541
this can be used to group commands into panels in the help output.
25302542
"""
2531-
if not called_from_module():
2543+
if called_from_command_definition():
25322544
return initialize(
25332545
name=name,
25342546
cls=cls,
@@ -2632,7 +2644,7 @@ def new_command(self, ...):
26322644
:param rich_help_panel: the rich help panel to use - if rich is installed
26332645
this can be used to group commands into panels in the help output.
26342646
"""
2635-
if not called_from_module():
2647+
if called_from_command_definition():
26362648
return command(
26372649
name=name,
26382650
cls=cls,
@@ -2739,7 +2751,7 @@ def grp_command(self, ...):
27392751
:param rich_help_panel: the rich help panel to use - if rich is installed
27402752
this can be used to group commands into panels in the help output.
27412753
"""
2742-
if not called_from_module():
2754+
if called_from_command_definition():
27432755
return group(
27442756
name=name,
27452757
cls=cls,

django_typer/tests/apps/test_app/management/commands/chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def command1(self, option: t.Optional[str] = None):
2121
print("command1")
2222
return option
2323

24-
@command()
24+
@TyperCommand.command()
2525
def command2(self, option: t.Optional[str] = None):
2626
"""This is a *markdown* help string"""
2727
print("command2")

django_typer/tests/test_adapter_pattern.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1082,7 +1082,7 @@ def test_group_adapter_precedence(self):
10821082
# adapted2.__class__.subsub_grp2(),
10831083
# "adapter2::adapted2()::grp2()::sub_grp2()::subsub_grp2()",
10841084
# )
1085-
self.assertFalse(hasattr(adapted2.__class__, "subsub_grp2"))
1085+
self.assertTrue(hasattr(adapted2.__class__, "subsub_grp2"))
10861086
self.assertEqual(
10871087
adapted2.subsub_grp2(),
10881088
"adapter2::adapted2()::grp2()::sub_grp2()::subsub_grp2()",

django_typer/tests/test_interface.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def test_action_nargs(self):
262262
self.assertEqual(multi_parser._actions[8].nargs, 0)
263263

264264
def test_cmd_getattr(self):
265+
from django_typer import TyperCommand
265266
from django_typer.tests.apps.test_app.management.commands.groups import (
266267
Command as Groups,
267268
)
@@ -278,3 +279,9 @@ def test_cmd_getattr(self):
278279
self.assertFalse(True, "should have thrown AttributeError")
279280
except AttributeError as e:
280281
pass
282+
283+
try:
284+
TyperCommand.does_not_exist
285+
self.assertFalse(True, "should have thrown AttributeError")
286+
except AttributeError as e:
287+
pass

django_typer/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import shutil
1010
import sys
1111
import typing as t
12+
from functools import partial
1213
from pathlib import Path
1314
from threading import local
1415
from types import MethodType, ModuleType
@@ -163,21 +164,27 @@ def load_command_extensions(command: str) -> int:
163164
return len(extensions)
164165

165166

166-
def called_from_module() -> bool:
167+
def _check_call_frame(frame_name: str) -> bool:
167168
"""
168-
Returns True if the stack frame one frame above where this function is called was at module
169-
scope. Regrettable interface simplifying voodoo. This is, at least, reliable.
169+
Returns True if the stack frame one frame above where this function has the given
170+
name.
171+
172+
:param frame_name: The name of the frame to check for
170173
"""
171174
frame = inspect.currentframe()
172175
for _ in range(0, 2):
173176
if not frame:
174177
break
175178
frame = frame.f_back
176179
if frame:
177-
return frame.f_code.co_name == "<module>"
180+
return frame.f_code.co_name == frame_name
178181
return False
179182

180183

184+
called_from_module = partial(_check_call_frame, "<module>")
185+
called_from_command_definition = partial(_check_call_frame, "Command")
186+
187+
181188
def is_method(
182189
func_or_params: t.Optional[t.Union[t.Callable[..., t.Any], t.List[str]]],
183190
) -> t.Optional[bool]:

0 commit comments

Comments
 (0)