Skip to content

Commit f14695b

Browse files
committed
binding iterator
1 parent 8256ca6 commit f14695b

File tree

5 files changed

+64
-37
lines changed

5 files changed

+64
-37
lines changed

src/textual/app.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
from .actions import ActionParseResult, SkipAction
8181
from .await_complete import AwaitComplete
8282
from .await_remove import AwaitRemove
83-
from .binding import Binding, BindingType, _Bindings
83+
from .binding import Binding, BindingsMap, BindingType
8484
from .command import CommandPalette, Provider
8585
from .css.errors import StylesheetError
8686
from .css.query import NoMatches
@@ -3000,14 +3000,14 @@ def bell(self) -> None:
30003000
self._driver.write("\07")
30013001

30023002
@property
3003-
def _binding_chain(self) -> list[tuple[DOMNode, _Bindings]]:
3003+
def _binding_chain(self) -> list[tuple[DOMNode, BindingsMap]]:
30043004
"""Get a chain of nodes and bindings to consider.
30053005
30063006
If no widget is focused, returns the bindings from both the screen and the app level bindings.
30073007
Otherwise, combines all the bindings from the currently focused node up the DOM to the root App.
30083008
"""
30093009
focused = self.focused
3010-
namespace_bindings: list[tuple[DOMNode, _Bindings]]
3010+
namespace_bindings: list[tuple[DOMNode, BindingsMap]]
30113011

30123012
if focused is None:
30133013
namespace_bindings = [

src/textual/binding.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from __future__ import annotations
99

1010
from dataclasses import dataclass
11-
from typing import TYPE_CHECKING, Iterable, NamedTuple
11+
from typing import TYPE_CHECKING, Iterable, Iterator, NamedTuple
1212

1313
import rich.repr
1414

@@ -64,7 +64,7 @@ class ActiveBinding(NamedTuple):
6464

6565

6666
@rich.repr.auto
67-
class _Bindings:
67+
class BindingsMap:
6868
"""Manage a set of bindings."""
6969

7070
def __init__(
@@ -118,21 +118,39 @@ def make_bindings(bindings: Iterable[BindingType]) -> Iterable[Binding]:
118118
for binding in make_bindings(bindings or {}):
119119
self.keys.setdefault(binding.key, []).append(binding)
120120

121-
def copy(self) -> _Bindings:
121+
def __iter__(self) -> Iterator[tuple[str, Binding]]:
122+
return iter(
123+
[
124+
(key, bindings)
125+
for key, bindings in self.keys.items()
126+
for binding in bindings
127+
]
128+
)
129+
# for key, bindings in self.keys.items():
130+
# for binding in bindings:
131+
# yield key, binding
132+
133+
@classmethod
134+
def from_keys(cls, keys: dict[str, list[Binding]]) -> BindingsMap:
135+
bindings = cls()
136+
bindings.keys = keys
137+
return bindings
138+
139+
def copy(self) -> BindingsMap:
122140
"""Return a copy of this instance.
123141
124142
Return:
125143
New bindings object.
126144
"""
127-
copy = _Bindings()
145+
copy = BindingsMap()
128146
copy.keys = self.keys.copy()
129147
return copy
130148

131149
def __rich_repr__(self) -> rich.repr.Result:
132150
yield self.keys
133151

134152
@classmethod
135-
def merge(cls, bindings: Iterable[_Bindings]) -> _Bindings:
153+
def merge(cls, bindings: Iterable[BindingsMap]) -> BindingsMap:
136154
"""Merge a bindings. Subsequent bound keys override initial keys.
137155
138156
Args:
@@ -146,9 +164,7 @@ def merge(cls, bindings: Iterable[_Bindings]) -> _Bindings:
146164
for key, key_bindings in _bindings.keys.items():
147165
keys.setdefault(key, []).extend(key_bindings)
148166

149-
new_bindings = _Bindings()
150-
new_bindings.keys = keys
151-
return new_bindings
167+
return BindingsMap.from_keys(keys)
152168

153169
@property
154170
def shown_keys(self) -> list[Binding]:

src/textual/dom.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from ._context import NoActiveAppError, active_message_pump
3333
from ._node_list import NodeList
3434
from ._types import WatchCallbackType
35-
from .binding import Binding, BindingType, _Bindings
35+
from .binding import Binding, BindingsMap, BindingType
3636
from .color import BLACK, WHITE, Color
3737
from .css._error_tools import friendly_list
3838
from .css.constants import VALID_DISPLAY, VALID_VISIBILITY
@@ -158,7 +158,7 @@ class DOMNode(MessagePump):
158158
_css_type_name: str = ""
159159

160160
# Generated list of bindings
161-
_merged_bindings: ClassVar[_Bindings | None] = None
161+
_merged_bindings: ClassVar[BindingsMap | None] = None
162162

163163
_reactives: ClassVar[dict[str, Reactive]]
164164

@@ -197,7 +197,7 @@ def __init__(
197197
self._auto_refresh_timer: Timer | None = None
198198
self._css_types = {cls.__name__ for cls in self._css_bases(self.__class__)}
199199
self._bindings = (
200-
_Bindings()
200+
BindingsMap()
201201
if self._merged_bindings is None
202202
else self._merged_bindings.copy()
203203
)
@@ -590,27 +590,30 @@ def _css_bases(cls, base: Type[DOMNode]) -> Sequence[Type[DOMNode]]:
590590
return classes
591591

592592
@classmethod
593-
def _merge_bindings(cls) -> _Bindings:
593+
def _merge_bindings(cls) -> BindingsMap:
594594
"""Merge bindings from base classes.
595595
596596
Returns:
597597
Merged bindings.
598598
"""
599-
bindings: list[_Bindings] = []
599+
bindings: list[BindingsMap] = []
600600

601601
for base in reversed(cls.__mro__):
602602
if issubclass(base, DOMNode):
603603
if not base._inherit_bindings:
604604
bindings.clear()
605605
bindings.append(
606-
_Bindings(
606+
BindingsMap(
607607
base.__dict__.get("BINDINGS", []),
608608
)
609609
)
610-
keys: dict[str, Binding] = {}
610+
keys: dict[str, list[Binding]] = {}
611611
for bindings_ in bindings:
612-
keys.update(bindings_.keys)
613-
return _Bindings(keys.values())
612+
for key, key_bindings in bindings_.keys.items():
613+
keys.setdefault(key, []).extend(key_bindings)
614+
615+
new_bindings = BindingsMap().from_keys(keys)
616+
return new_bindings
614617

615618
def _post_register(self, app: App) -> None:
616619
"""Called when the widget is registered

src/textual/screen.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from ._path import CSSPathType, _css_path_type_as_list, _make_path_object_relative
3737
from ._types import CallbackType
3838
from .await_complete import AwaitComplete
39-
from .binding import ActiveBinding, Binding, _Bindings
39+
from .binding import ActiveBinding, Binding, BindingsMap
4040
from .css.match import match
4141
from .css.parse import parse_selectors
4242
from .css.query import NoMatches, QueryType
@@ -289,12 +289,12 @@ def refresh_bindings(self) -> None:
289289
self.check_idle()
290290

291291
@property
292-
def _binding_chain(self) -> list[tuple[DOMNode, _Bindings]]:
292+
def _binding_chain(self) -> list[tuple[DOMNode, BindingsMap]]:
293293
"""Binding chain from this screen."""
294294
focused = self.focused
295295
if focused is not None and focused.loading:
296296
focused = None
297-
namespace_bindings: list[tuple[DOMNode, _Bindings]]
297+
namespace_bindings: list[tuple[DOMNode, BindingsMap]]
298298

299299
if focused is None:
300300
namespace_bindings = [
@@ -309,7 +309,7 @@ def _binding_chain(self) -> list[tuple[DOMNode, _Bindings]]:
309309
return namespace_bindings
310310

311311
@property
312-
def _modal_binding_chain(self) -> list[tuple[DOMNode, _Bindings]]:
312+
def _modal_binding_chain(self) -> list[tuple[DOMNode, BindingsMap]]:
313313
"""The binding chain, ignoring everything before the last modal."""
314314
binding_chain = self._binding_chain
315315
for index, (node, _bindings) in enumerate(binding_chain, 1):
@@ -332,7 +332,7 @@ def active_bindings(self) -> dict[str, ActiveBinding]:
332332

333333
bindings_map: dict[str, ActiveBinding] = {}
334334
for namespace, bindings in self._modal_binding_chain:
335-
for key, binding in bindings.keys.items():
335+
for key, binding in bindings:
336336
action_state = self.app._check_action_state(binding.action, namespace)
337337
if action_state is False:
338338
continue

tests/test_binding.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44

55
from textual.app import App
6-
from textual.binding import Binding, BindingError, InvalidBinding, NoBinding, _Bindings
6+
from textual.binding import (
7+
Binding,
8+
BindingError,
9+
BindingsMap,
10+
InvalidBinding,
11+
NoBinding,
12+
)
713

814
BINDING1 = Binding("a,b", action="action1", description="description1")
915
BINDING2 = Binding("c", action="action2", description="description2")
@@ -12,12 +18,12 @@
1218

1319
@pytest.fixture
1420
def bindings():
15-
yield _Bindings([BINDING1, BINDING2])
21+
yield BindingsMap([BINDING1, BINDING2])
1622

1723

1824
@pytest.fixture
1925
def more_bindings():
20-
yield _Bindings([BINDING1, BINDING2, BINDING3])
26+
yield BindingsMap([BINDING1, BINDING2, BINDING3])
2127

2228

2329
def test_bindings_get_key(bindings):
@@ -34,38 +40,40 @@ def test_bindings_get_key_spaced_list(more_bindings):
3440

3541

3642
def test_bindings_merge_simple(bindings):
37-
left = _Bindings([BINDING1])
38-
right = _Bindings([BINDING2])
39-
assert _Bindings.merge([left, right]).keys == bindings.keys
43+
left = BindingsMap([BINDING1])
44+
right = BindingsMap([BINDING2])
45+
assert BindingsMap.merge([left, right]).keys == bindings.keys
4046

4147

4248
def test_bindings_merge_overlap():
43-
left = _Bindings([BINDING1])
49+
left = BindingsMap([BINDING1])
4450
another_binding = Binding(
4551
"a", action="another_action", description="another_description"
4652
)
47-
assert _Bindings.merge([left, _Bindings([another_binding])]).keys == {
53+
assert BindingsMap.merge([left, BindingsMap([another_binding])]).keys == {
4854
"a": another_binding,
4955
"b": Binding("b", action="action1", description="description1"),
5056
}
5157

5258

5359
def test_bad_binding_tuple():
5460
with pytest.raises(BindingError):
55-
_ = _Bindings((("a",),))
61+
_ = BindingsMap((("a",),))
5662
with pytest.raises(BindingError):
57-
_ = _Bindings((("a", "action", "description", "too much"),))
63+
_ = BindingsMap((("a", "action", "description", "too much"),))
5864

5965

6066
def test_binding_from_tuples():
6167
assert (
62-
_Bindings(((BINDING2.key, BINDING2.action, BINDING2.description),)).get_key("c")
68+
BindingsMap(((BINDING2.key, BINDING2.action, BINDING2.description),)).get_key(
69+
"c"
70+
)
6371
== BINDING2
6472
)
6573

6674

6775
def test_shown():
68-
bindings = _Bindings(
76+
bindings = BindingsMap(
6977
[
7078
Binding(
7179
key,

0 commit comments

Comments
 (0)