|
82 | 82 | import inspect
|
83 | 83 | import sys
|
84 | 84 | import typing as t
|
| 85 | +from collections import deque |
85 | 86 | from copy import copy, deepcopy
|
86 | 87 | from functools import cached_property
|
87 | 88 | from importlib import import_module
|
@@ -902,16 +903,12 @@ def __get__(self, obj, _=None) -> "Typer[P, R]":
|
902 | 903 | def __getattr__(self, name: str) -> t.Any:
|
903 | 904 | for cmd in self.registered_commands:
|
904 | 905 | assert cmd.callback
|
905 |
| - if name in (cmd.callback.__name__, cmd.name): |
| 906 | + if name in _names(cmd): |
906 | 907 | return cmd
|
907 | 908 | for grp in self.registered_groups:
|
908 | 909 | cmd_grp = t.cast(Typer, grp.typer_instance)
|
909 | 910 | assert cmd_grp
|
910 |
| - if name in ( |
911 |
| - cmd_grp.name, |
912 |
| - grp.name, |
913 |
| - getattr(cmd_grp.info.callback, "__name__", None), |
914 |
| - ): |
| 911 | + if name in _names(cmd_grp): |
915 | 912 | return cmd_grp
|
916 | 913 | raise AttributeError(
|
917 | 914 | "{cls} object has no attribute {name}".format(
|
@@ -1753,37 +1750,76 @@ def _resolve_help(dj_cmd: "TyperCommand"):
|
1753 | 1750 | dj_cmd.typer_app.info.help = hlp
|
1754 | 1751 |
|
1755 | 1752 |
|
1756 |
| -def depth_first_match( |
1757 |
| - app: typer.Typer, name: str |
1758 |
| -) -> t.Optional[t.Union[typer.models.CommandInfo, Typer]]: |
| 1753 | +def _names(tc: t.Union[typer.models.CommandInfo, Typer]) -> t.List[str]: |
1759 | 1754 | """
|
1760 |
| - Perform a depth first search for a command or group by name. |
| 1755 | + For a command or group, get a list of attribute name and its CLI name. |
1761 | 1756 |
|
1762 |
| - TODO - should be breadth first |
| 1757 | + This annoyingly lives in difference places depending on how the command |
| 1758 | + or group was defined. This logic is sensitive to typer internals. |
| 1759 | + """ |
| 1760 | + names = [] |
| 1761 | + if isinstance(tc, typer.models.CommandInfo): |
| 1762 | + assert tc.callback |
| 1763 | + names.append(tc.callback.__name__) |
| 1764 | + if tc.name and tc.name != tc.callback.__name__: |
| 1765 | + names.append(tc.name) |
| 1766 | + else: |
| 1767 | + if tc.name: |
| 1768 | + names.append(tc.name) |
| 1769 | + if tc.info.name and tc.info.name != tc.name: |
| 1770 | + names.append(tc.info.name) |
| 1771 | + cb_name = getattr(tc.info.callback, "__name__", None) |
| 1772 | + if cb_name and cb_name not in names: |
| 1773 | + names.append(cb_name) |
| 1774 | + return names |
| 1775 | + |
| 1776 | + |
| 1777 | +def _bfs_match( |
| 1778 | + app: Typer, name: str |
| 1779 | +) -> t.Optional[t.Union[typer.models.CommandInfo, Typer]]: |
| 1780 | + """ |
| 1781 | + Perform a breadth first search for a command or group by name. |
1763 | 1782 |
|
1764 | 1783 | :param app: The Typer app to search.
|
1765 | 1784 | :param name: The name of the command or group to search for.
|
1766 | 1785 | :return: The command or group if found, otherwise None.
|
1767 | 1786 | """
|
1768 |
| - for cmd in reversed(app.registered_commands): |
1769 |
| - if name in [cmd.name, *([cmd.callback.__name__] if cmd.callback else [])]: |
1770 |
| - return cmd |
1771 |
| - for grp in reversed(app.registered_groups): |
1772 |
| - assert grp.typer_instance |
1773 |
| - grp_app = t.cast(Typer, grp.typer_instance) |
1774 |
| - # some weirdness, grp_app.info not always == grp |
1775 |
| - # todo __deepcopy__ problem? |
1776 |
| - # assert grp_app.info is grp |
1777 |
| - if name in [ |
1778 |
| - grp.name, |
1779 |
| - grp_app.name, |
1780 |
| - getattr(grp_app.info.callback, "__name__", None), |
1781 |
| - ]: |
1782 |
| - return grp_app |
1783 |
| - for grp in reversed(app.registered_groups): |
1784 |
| - assert grp.typer_instance |
1785 |
| - grp_app = t.cast(Typer, grp.typer_instance) |
1786 |
| - found = depth_first_match(grp_app, name) |
| 1787 | + |
| 1788 | + def find_at_level( |
| 1789 | + lvl: Typer, |
| 1790 | + ) -> t.Optional[t.Union[typer.models.CommandInfo, Typer]]: |
| 1791 | + for cmd in reversed(lvl.registered_commands): |
| 1792 | + if name in _names(cmd): |
| 1793 | + return cmd |
| 1794 | + if name in _names(lvl): |
| 1795 | + return lvl |
| 1796 | + return None |
| 1797 | + |
| 1798 | + # fast exit out if at top level (most searches - avoid building BFS) |
| 1799 | + if found := find_at_level(app): |
| 1800 | + return found |
| 1801 | + |
| 1802 | + visited = set() |
| 1803 | + bfs_order: t.List[Typer] = [] |
| 1804 | + queue = deque([app]) |
| 1805 | + |
| 1806 | + while queue: |
| 1807 | + grp = queue.popleft() |
| 1808 | + if grp not in visited: |
| 1809 | + visited.add(grp) |
| 1810 | + bfs_order.append(grp) |
| 1811 | + # if names conflict, only pick the first the others have been |
| 1812 | + # overridden - avoids walking down stale branches |
| 1813 | + seen = [] |
| 1814 | + for child_grp in reversed(grp.registered_groups): |
| 1815 | + child_app = t.cast(Typer, child_grp.typer_instance) |
| 1816 | + assert child_app |
| 1817 | + if child_app not in visited and child_app.name not in seen: |
| 1818 | + seen.extend(_names(child_app)) |
| 1819 | + queue.append(child_app) |
| 1820 | + |
| 1821 | + for grp in bfs_order[1:]: |
| 1822 | + found = find_at_level(grp) |
1787 | 1823 | if found:
|
1788 | 1824 | return found
|
1789 | 1825 | return None
|
@@ -2057,15 +2093,15 @@ def __init__(cls, cls_name, bases, attrs, **kwargs):
|
2057 | 2093 |
|
2058 | 2094 | def __getattr__(cls, name: str) -> t.Any:
|
2059 | 2095 | """
|
2060 |
| - Fall back depth first search of the typer app tree to resolve attribute accesses of the type: |
| 2096 | + Fall back breadth first search of the typer app tree to resolve attribute accesses of the type: |
2061 | 2097 | Command.sub_grp or Command.sub_cmd
|
2062 | 2098 | """
|
2063 | 2099 | if name != "typer_app":
|
2064 | 2100 | if called_from_command_definition():
|
2065 | 2101 | if name in cls._defined_groups:
|
2066 | 2102 | return cls._defined_groups[name]
|
2067 | 2103 | elif cls.typer_app:
|
2068 |
| - found = depth_first_match(cls.typer_app, name) |
| 2104 | + found = _bfs_match(cls.typer_app, name) |
2069 | 2105 | if found:
|
2070 | 2106 | return found
|
2071 | 2107 | raise AttributeError(
|
@@ -2957,7 +2993,7 @@ def __getattr__(self, name: str) -> t.Any:
|
2957 | 2993 | )
|
2958 | 2994 | if init and init and name == init.__name__:
|
2959 | 2995 | return BoundProxy(self, init)
|
2960 |
| - found = depth_first_match(self.typer_app, name) |
| 2996 | + found = _bfs_match(self.typer_app, name) |
2961 | 2997 | if found:
|
2962 | 2998 | return BoundProxy(self, found)
|
2963 | 2999 | raise AttributeError(
|
|
0 commit comments