Skip to content

Commit c0d9fa4

Browse files
committed
Implement collections.abc.Mapping for Multiaddr to give it a dict-style feel
Nit: All operations on this “dictionary” (including __len__ and __getitem__) are O(n) over the number of bytes/protocols of the MultiAddr, rather then using more optimized code that would require pre-parsing the binary MultiAddr. Considering that MultiAddr tend to be rather short this should hopefully not be an issue, but changing this would be possible in the future without changing any part of the interface.
1 parent 953c503 commit c0d9fa4

File tree

4 files changed

+147
-27
lines changed

4 files changed

+147
-27
lines changed

multiaddr/multiaddr.py

Lines changed: 99 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
# -*- coding: utf-8 -*-
2+
try:
3+
import collections.abc
4+
except ImportError: # pragma: no cover (PY2)
5+
import collections
6+
collections.abc = collections
27
from copy import copy
38

49
import six
@@ -10,7 +15,76 @@
1015
from .transforms import bytes_to_string
1116

1217

13-
class Multiaddr(object):
18+
__all__ = ("Multiaddr",)
19+
20+
21+
22+
class MultiAddrKeys(collections.abc.KeysView, collections.abc.Sequence):
23+
def __contains__(self, proto):
24+
proto = protocols.protocol_with_any(proto)
25+
return collections.abc.Sequence.__contains__(self, proto)
26+
27+
def __getitem__(self, idx):
28+
if idx < 0:
29+
idx = len(self)+idx
30+
for idx2, proto in enumerate(self):
31+
if idx2 == idx:
32+
return proto
33+
raise IndexError("Protocol list index out of range")
34+
35+
__hash__ = collections.abc.KeysView._hash
36+
37+
def __iter__(self):
38+
for proto, _, _ in bytes_iter(self._mapping.to_bytes()):
39+
yield proto
40+
41+
42+
class MultiAddrItems(collections.abc.ItemsView, collections.abc.Sequence):
43+
def __contains__(self, item):
44+
proto, value = item
45+
proto = protocols.protocol_with_any(proto)
46+
return collections.abc.Sequence.__contains__(self, (proto, value))
47+
48+
def __getitem__(self, idx):
49+
if idx < 0:
50+
idx = len(self)+idx
51+
for idx2, item in enumerate(self):
52+
if idx2 == idx:
53+
return item
54+
raise IndexError("Protocol item list index out of range")
55+
56+
def __iter__(self):
57+
for proto, codec, part in bytes_iter(self._mapping.to_bytes()):
58+
if codec.SIZE != 0:
59+
try:
60+
# If we have an address, return it
61+
yield proto, codec.to_string(proto, part)
62+
except Exception as exc:
63+
six.raise_from(exceptions.BinaryParseError(str(exc), self._mapping.to_bytes(), proto.name, exc), exc)
64+
else:
65+
# We were given something like '/utp', which doesn't have
66+
# an address, so return None
67+
yield proto, None
68+
69+
70+
class MultiAddrValues(collections.abc.ValuesView, collections.abc.Sequence):
71+
__contains__ = collections.abc.Sequence.__contains__
72+
73+
def __getitem__(self, idx):
74+
if idx < 0:
75+
idx = len(self)+idx
76+
for idx2, proto in enumerate(self):
77+
if idx2 == idx:
78+
return proto
79+
raise IndexError("Protocol value list index out of range")
80+
81+
def __iter__(self):
82+
for _, value in MultiAddrItems(self._mapping):
83+
yield value
84+
85+
86+
87+
class Multiaddr(collections.abc.Mapping):
1488
"""Multiaddr is a representation of multiple nested internet addresses.
1589
1690
Multiaddr is a cross-protocol, cross-platform format for representing
@@ -53,16 +127,22 @@ def __eq__(self, other):
53127
"""Checks if two Multiaddr objects are exactly equal."""
54128
return self._bytes == other._bytes
55129

56-
def __ne__(self, other):
57-
return not (self == other)
58-
59130
def __str__(self):
60131
"""Return the string representation of this Multiaddr.
61132
62133
May raise a :class:`~multiaddr.exceptions.BinaryParseError` if the
63134
stored MultiAddr binary representation is invalid."""
64135
return bytes_to_string(self._bytes)
65136

137+
def __contains__(self, proto):
138+
return proto in MultiAddrKeys(self)
139+
140+
def __iter__(self):
141+
return iter(MultiAddrKeys(self))
142+
143+
def __len__(self):
144+
return sum((1 for _ in bytes_iter(self.to_bytes())))
145+
66146
# On Python 2 __str__ needs to return binary text, so expose the original
67147
# function as __unicode__ and transparently encode its returned text based
68148
# on the current locale
@@ -82,7 +162,15 @@ def to_bytes(self):
82162

83163
def protocols(self):
84164
"""Returns a list of Protocols this Multiaddr includes."""
85-
return list(proto for proto, _, _ in bytes_iter(self.to_bytes()))
165+
return MultiAddrKeys(self)
166+
167+
keys = protocols
168+
169+
def items(self):
170+
return MultiAddrItems(self)
171+
172+
def values(self):
173+
return MultiAddrValues(self)
86174

87175
def encapsulate(self, other):
88176
"""Wrap this Multiaddr around another.
@@ -125,24 +213,10 @@ def value_for_protocol(self, proto):
125213
~multiaddr.exceptions.ProtocolLookupError
126214
MultiAddr does not contain any instance of this protocol
127215
"""
128-
if not isinstance(proto, protocols.Protocol):
129-
if isinstance(proto, int):
130-
proto = protocols.protocol_with_code(proto)
131-
elif isinstance(proto, six.string_types):
132-
proto = protocols.protocol_with_name(proto)
133-
else:
134-
raise TypeError("Protocol object, name or code expected, got {0!r}".format(proto))
135-
136-
for proto2, codec, part in bytes_iter(self.to_bytes()):
137-
if proto2 == proto:
138-
if codec.SIZE != 0:
139-
try:
140-
# If we have an address, return it
141-
return codec.to_string(proto2, part)
142-
except Exception as exc:
143-
six.raise_from(exceptions.BinaryParseError(str(exc), self.to_bytes(), proto2.name, exc), exc)
144-
else:
145-
# We were given something like '/utp', which doesn't have
146-
# an address, so return None
147-
return None
216+
proto = protocols.protocol_with_any(proto)
217+
for proto2, value in self.items():
218+
if proto2 is proto or proto2 == proto:
219+
return value
148220
raise exceptions.ProtocolLookupError(proto, str(self))
221+
222+
__getitem__ = value_for_protocol

multiaddr/protocols.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,16 @@ def vcode(self):
100100
return varint.encode(self.code)
101101

102102
def __eq__(self, other):
103+
if not isinstance(other, Protocol):
104+
return NotImplemented
105+
103106
return all((self.code == other.code,
104107
self.name == other.name,
105108
self.codec == other.codec,
106109
self.path == other.path))
107110

108-
def __ne__(self, other):
109-
return not self == other
111+
def __hash__(self):
112+
return self.code
110113

111114
def __repr__(self):
112115
return "Protocol(code={code!r}, name={name!r}, codec={codec!r})".format(
@@ -200,6 +203,17 @@ def protocol_with_code(code):
200203
return _codes_to_protocols[code]
201204

202205

206+
def protocol_with_any(proto):
207+
if isinstance(proto, Protocol):
208+
return proto
209+
elif isinstance(proto, int):
210+
return protocol_with_code(proto)
211+
elif isinstance(proto, six.string_types):
212+
return protocol_with_name(proto)
213+
else:
214+
raise TypeError("Protocol object, name or code expected, got {0!r}".format(proto))
215+
216+
203217
def protocols_with_string(string):
204218
"""Return a list of protocols matching given string."""
205219
# Normalize string

tests/test_multiaddr.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,34 @@ def test_get_value():
234234
assert_value_for_proto(a, P_UNIX, "/studio") # only a path.
235235

236236

237+
def test_views():
238+
ma = Multiaddr(
239+
"/ip4/127.0.0.1/utp/tcp/5555/udp/1234/utp/"
240+
"p2p/QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP")
241+
242+
for idx, (proto1, proto2, item, value) in enumerate(zip(ma, ma.keys(), ma.items(), ma.values())):
243+
assert (proto1, value) == (proto2, value) == item
244+
assert proto1 in ma
245+
assert proto2 in ma.keys()
246+
assert item in ma.items()
247+
assert value in ma.values()
248+
assert ma.keys()[idx] == ma.keys()[idx-len(ma)] == proto1 == proto2
249+
assert ma.items()[idx] == ma.items()[idx-len(ma)] == item
250+
assert ma.values()[idx] == ma.values()[idx-len(ma)] == ma[proto1] == value
251+
252+
assert len(ma.keys()) == len(ma.items()) == len(ma.values()) == len(ma)
253+
assert len(list(ma.keys())) == len(ma.keys())
254+
assert len(list(ma.items())) == len(ma.items())
255+
assert len(list(ma.values())) == len(ma.values())
256+
257+
with pytest.raises(IndexError):
258+
ma.keys()[len(ma)]
259+
with pytest.raises(IndexError):
260+
ma.items()[len(ma)]
261+
with pytest.raises(IndexError):
262+
ma.values()[len(ma)]
263+
264+
237265
def test_bad_initialization_no_params():
238266
with pytest.raises(TypeError):
239267
Multiaddr()

tests/test_protocols.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def test_protocol_with_name():
7272
assert proto.code == protocols.P_IP4
7373
assert proto.size == 32
7474
assert proto.vcode == varint.encode(protocols.P_IP4)
75+
assert hash(proto) == protocols.P_IP4
7576

7677
with pytest.raises(exceptions.ProtocolNotFoundError):
7778
proto = protocols.protocol_with_name('foo')
@@ -83,6 +84,7 @@ def test_protocol_with_code():
8384
assert proto.code == protocols.P_IP4
8485
assert proto.size == 32
8586
assert proto.vcode == varint.encode(protocols.P_IP4)
87+
assert hash(proto) == protocols.P_IP4
8688

8789
with pytest.raises(exceptions.ProtocolNotFoundError):
8890
proto = protocols.protocol_with_code(1234)
@@ -95,6 +97,8 @@ def test_protocol_equality():
9597

9698
assert proto1 == proto2
9799
assert proto1 != proto3
100+
assert proto1 != None
101+
assert proto2 != str(proto2)
98102

99103

100104
@pytest.mark.parametrize("names", [['ip4'],

0 commit comments

Comments
 (0)