83
83
import threading
84
84
import typing as t
85
85
from copy import copy , deepcopy
86
+ from functools import cached_property
86
87
from importlib import import_module
87
88
from pathlib import Path
88
89
from types import MethodType , SimpleNamespace
@@ -270,7 +271,7 @@ def get_command( # type: ignore[overload-overlap]
270
271
no_color : bool = False ,
271
272
force_color : bool = False ,
272
273
** kwargs ,
273
- ) -> CallableCommand : ...
274
+ ) -> BaseCommand : ...
274
275
275
276
276
277
@t .overload # pragma: no cover
@@ -794,7 +795,12 @@ def register(
794
795
795
796
def _get_direct_function (
796
797
obj : "TyperCommand" ,
797
- app_node : t .Union ["Typer" , typer .models .CommandInfo , typer .models .TyperInfo ],
798
+ app_node : t .Union [
799
+ "Typer" ,
800
+ typer .models .CommandInfo ,
801
+ typer .models .TyperInfo ,
802
+ t .Callable [..., t .Any ],
803
+ ],
798
804
):
799
805
"""
800
806
Get a direct callable function bound to the given object if it is not static held by the given
@@ -803,8 +809,11 @@ def _get_direct_function(
803
809
if isinstance (app_node , Typer ):
804
810
method = app_node .is_method
805
811
cb = getattr (app_node .registered_callback , "callback" , app_node .info .callback )
812
+ elif cb := getattr (app_node , "callback" , None ):
813
+ method = is_method (cb )
806
814
else :
807
- cb = app_node .callback
815
+ assert callable (app_node )
816
+ cb = app_node
808
817
method = is_method (cb )
809
818
assert cb
810
819
return MethodType (cb , obj ) if method else staticmethod (cb )
@@ -886,36 +895,10 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
886
895
return _get_direct_function (cmd , self )(* args , ** kwargs )
887
896
return super ().__call__ (* args , ** kwargs )
888
897
889
- @t .overload # pragma: no cover
890
- def __get__ (
891
- self , obj : "TyperCommandMeta" , owner : t .Type ["TyperCommandMeta" ]
892
- ) -> "Typer[P, R]" : ...
893
-
894
- @t .overload # pragma: no cover
895
- def __get__ (
896
- self ,
897
- obj : None ,
898
- owner : t .Type ["TyperCommand" ],
899
- ) -> "Typer[P, R]" : ...
900
-
901
- @t .overload # pragma: no cover
902
- def __get__ (
903
- self ,
904
- obj : "Typer" ,
905
- owner : t .Type ["Typer" ],
906
- ) -> "Typer[P, R]" : ...
907
-
908
- @t .overload # pragma: no cover
909
- def __get__ (
910
- self , obj : "TyperCommand" , owner : t .Any = None
911
- ) -> "Typer[P, R]" : # t.Union[MethodType, t.Callable[P, R]]
912
- # todo - we could return the generic callable type here but the problem
913
- # is self is included in the ParamSpec and it seems tricky to remove?
914
- # MethodType loses the parameters but is preferable to type checking errors
915
- # https://github.com/bckohan/django-typer/issues/73
916
- ...
917
-
918
- def __get__ (self , obj , owner = None ): # pyright: ignore[reportInconsistentOverload]
898
+ # todo - this results in type hinting expecting self to be passed explicitly
899
+ # when this is called as a callable
900
+ # https://github.com/bckohan/django-typer/issues/73
901
+ def __get__ (self , obj , _ = None ) -> "Typer[P, R]" :
919
902
"""
920
903
Our Typer app wrapper also doubles as a descriptor, so when
921
904
it is accessed on the instance, we return the wrapped function
@@ -2113,7 +2096,7 @@ class CommandNode:
2113
2096
The click command object that this node represents.
2114
2097
"""
2115
2098
2116
- context : TyperContext
2099
+ context : Context
2117
2100
"""
2118
2101
The Typer context object used to run this command.
2119
2102
"""
@@ -2123,38 +2106,50 @@ class CommandNode:
2123
2106
Back reference to the django command instance that this command belongs to.
2124
2107
"""
2125
2108
2126
- parent : t .Optional ["CommandNode" ] = None
2127
- """
2128
- The parent node of this command node or None if this is a root node.
2129
- """
2130
-
2131
- children : t .Dict [str , "CommandNode" ]
2132
- """
2133
- The child group and command nodes of this command node.
2134
- """
2109
+ @cached_property
2110
+ def children (self ) -> t .Dict [str , "CommandNode" ]:
2111
+ """
2112
+ The child group and command nodes of this command node.
2113
+ """
2114
+ return {
2115
+ name : CommandNode (name , cmd , self .django_command , parent = self .context )
2116
+ for name , cmd in getattr (
2117
+ self .context .command ,
2118
+ "commands" ,
2119
+ {
2120
+ name : self .context .command .get_command (self .context , name ) # type: ignore[attr-defined]
2121
+ for name in (
2122
+ self .context .command .list_commands (self .context )
2123
+ if isinstance (self .context .command , click .MultiCommand )
2124
+ else []
2125
+ )
2126
+ },
2127
+ ).items ()
2128
+ }
2135
2129
2136
2130
@property
2137
2131
def callback (self ) -> t .Callable [..., t .Any ]:
2138
2132
"""Get the function for this command or group"""
2139
- cb = getattr (self .click_command ._callback , "__wrapped__" )
2140
- return (
2141
- MethodType (cb , self .django_command ) if self .click_command .is_method else cb
2133
+ return _get_direct_function (
2134
+ self .django_command , getattr (self .click_command ._callback , "__wrapped__" )
2142
2135
)
2143
2136
2144
2137
def __init__ (
2145
2138
self ,
2146
2139
name : str ,
2147
2140
click_command : DjangoTyperMixin ,
2148
- context : TyperContext ,
2149
2141
django_command : "TyperCommand" ,
2150
- parent : t .Optional ["CommandNode" ] = None ,
2142
+ parent : t .Optional [Context ] = None ,
2151
2143
):
2152
2144
self .name = name
2153
2145
self .click_command = click_command
2154
- self .context = context
2155
2146
self .django_command = django_command
2156
- self .parent = parent
2157
- self .children = {}
2147
+ self .context = Context (
2148
+ self .click_command ,
2149
+ info_name = name ,
2150
+ django_command = django_command ,
2151
+ parent = parent ,
2152
+ )
2158
2153
2159
2154
def print_help (self ) -> t .Optional [str ]:
2160
2155
"""
@@ -2235,9 +2230,12 @@ def nargs(self) -> int:
2235
2230
@property
2236
2231
def option_strings (self ) -> t .List [str ]:
2237
2232
"""
2238
- The list of allowable command line option strings for this parameter.
2233
+ call_command uses this to determine a mapping of supplied options to function
2234
+ arguments. I.e. it will remap option_string: dest. We don't want this because
2235
+ we'd rather have supplied parameters line up with their function arguments to
2236
+ allow deconfliction when CLI options share the same name.
2239
2237
"""
2240
- return list ( self . param . opts ) if isinstance ( self . param , click . Option ) else []
2238
+ return []
2241
2239
2242
2240
_actions : t .List [t .Any ]
2243
2241
_mutually_exclusive_groups : t .List [t .Any ] = []
@@ -2251,23 +2249,22 @@ def __init__(self, django_command: "TyperCommand", prog_name, subcommand):
2251
2249
self .django_command = django_command
2252
2250
self .prog_name = prog_name
2253
2251
self .subcommand = subcommand
2252
+ self .tree = self .django_command .command_tree
2253
+ self .tree .context .info_name = f"{ self .prog_name } { self .subcommand } "
2254
2254
2255
2255
def populate_params (node : CommandNode ) -> None :
2256
2256
for param in node .click_command .params :
2257
2257
self ._actions .append (self .Action (param ))
2258
2258
for child in node .children .values ():
2259
2259
populate_params (child )
2260
2260
2261
- populate_params (self .django_command . command_tree )
2261
+ populate_params (self .tree )
2262
2262
2263
2263
def print_help (self , * command_path : str ):
2264
2264
"""
2265
2265
Print the help for the given command path to stdout of the django command.
2266
2266
"""
2267
- self .django_command .command_tree .context .info_name = (
2268
- f"{ self .prog_name } { self .subcommand } "
2269
- )
2270
- command_node = self .django_command .get_subcommand (* command_path )
2267
+ command_node = self .tree .get_command (* command_path )
2271
2268
hlp = command_node .print_help ()
2272
2269
if hlp :
2273
2270
self .django_command .stdout .write (
@@ -2470,12 +2467,22 @@ def command2(self, option: t.Optional[str] = None):
2470
2467
2471
2468
help : t .Optional [t .Union [DefaultPlaceholder , str ]] = Default (None ) # type: ignore
2472
2469
2473
- command_tree : CommandNode
2474
-
2475
2470
# allow deriving commands to override handle() from BaseCommand
2476
2471
# without triggering static type checking complaints
2477
2472
handle = None # type: ignore
2478
2473
2474
+ @property
2475
+ def command_tree (self ) -> CommandNode :
2476
+ """
2477
+ Get the root CommandNode for this command. Allows easy traversal of the command
2478
+ tree.
2479
+ """
2480
+ return CommandNode (
2481
+ f"{ sys .argv [0 ]} { self ._name } " ,
2482
+ click_command = t .cast (DjangoTyperMixin , get_typer_command (self .typer_app )),
2483
+ django_command = self ,
2484
+ )
2485
+
2479
2486
@classmethod
2480
2487
def initialize (
2481
2488
cmd , # pyright: ignore[reportSelfClsParameterName]
@@ -2879,9 +2886,7 @@ def __init__(
2879
2886
self .stdout .style_func = stdout_style_func
2880
2887
self .stderr .style_func = stderr_style_func
2881
2888
try :
2882
- self .command_tree = self ._build_cmd_tree (
2883
- get_typer_command (self .typer_app )
2884
- )
2889
+ assert get_typer_command (self .typer_app )
2885
2890
except RuntimeError as rerr :
2886
2891
raise NotImplementedError (
2887
2892
_ (
@@ -2900,67 +2905,6 @@ def get_subcommand(self, *command_path: str) -> CommandNode:
2900
2905
"""
2901
2906
return self .command_tree .get_command (* command_path )
2902
2907
2903
- def _filter_commands (
2904
- self , ctx : TyperContext , cmd_filter : t .Optional [t .List [str ]] = None
2905
- ):
2906
- """
2907
- Fetch subcommand names. Given a click context, return the list of commands
2908
- that are valid return the list of commands that are valid for the given
2909
- context.
2910
-
2911
- :param ctx: the click context
2912
- :param cmd_filter: a list of command names to filter by, if None no subcommands
2913
- will be filtered out
2914
- :return: the list of command names that are valid for the given context
2915
- """
2916
- return sorted (
2917
- [
2918
- cmd
2919
- for name , cmd in getattr (
2920
- ctx .command ,
2921
- "commands" ,
2922
- {
2923
- name : ctx .command .get_command (ctx , name ) # type: ignore[attr-defined]
2924
- for name in (
2925
- ctx .command .list_commands (ctx )
2926
- if isinstance (ctx .command , click .MultiCommand )
2927
- else []
2928
- )
2929
- },
2930
- ).items ()
2931
- if not cmd_filter or name in cmd_filter
2932
- ],
2933
- key = lambda item : item .name ,
2934
- )
2935
-
2936
- def _build_cmd_tree (
2937
- self ,
2938
- cmd : click .Command ,
2939
- parent : t .Optional [Context ] = None ,
2940
- info_name : t .Optional [str ] = None ,
2941
- node : t .Optional [CommandNode ] = None ,
2942
- ):
2943
- """
2944
- Recursively build the CommandNode tree used to walk the click command
2945
- hierarchy.
2946
-
2947
- :param cmd: the click command to build the tree from
2948
- :param parent: the parent click context
2949
- :param info_name: the name of the command
2950
- :param node: the parent node or None if this is a root node
2951
- """
2952
- assert cmd .name
2953
- assert isinstance (cmd , DjangoTyperMixin )
2954
- ctx = Context (cmd , info_name = info_name , parent = parent , django_command = self )
2955
- current = CommandNode (cmd .name , cmd , ctx , self , parent = node )
2956
- if node :
2957
- node .children [cmd .name ] = current
2958
- for sub_cmd in self ._filter_commands (ctx ):
2959
- self ._build_cmd_tree (
2960
- sub_cmd , parent = ctx , info_name = sub_cmd .name , node = current
2961
- )
2962
- return current
2963
-
2964
2908
def __init_subclass__ (cls , ** _ ):
2965
2909
"""Avoid passing typer arguments up the subclass init chain"""
2966
2910
return super ().__init_subclass__ ()
@@ -3006,6 +2950,8 @@ def __getattr__(self, name: str) -> t.Any:
3006
2950
and return that command or group if the attribute name matches the command/group
3007
2951
function OR its registered CLI name.
3008
2952
"""
2953
+ if isinstance (attr := getattr (self .__class__ , name , None ), property ):
2954
+ return t .cast (t .Callable , attr .fget )(self )
3009
2955
init = getattr (
3010
2956
self .typer_app .registered_callback ,
3011
2957
"callback" ,
@@ -3016,7 +2962,8 @@ def __getattr__(self, name: str) -> t.Any:
3016
2962
found = depth_first_match (self .typer_app , name )
3017
2963
if found :
3018
2964
if isinstance (found , Typer ):
3019
- # todo shouldn't be needed
2965
+ # todo shouldn't be needed - wrap these in a proxy,
2966
+ # avoid need for threading local
3020
2967
found ._local .object = self
3021
2968
else :
3022
2969
return _get_direct_function (self , found )
0 commit comments