Skip to content

Commit d701b64

Browse files
authored
Merge pull request #152 from chaimleib/feat/147/copy-returns-subclasses
better subclassing support: methods creating new instances dynamically determine the instance class
2 parents 5f9a730 + 8a348f6 commit d701b64

File tree

9 files changed

+169
-68
lines changed

9 files changed

+169
-68
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Change log
22

3+
## Version 3.2.2
4+
- Fixed:
5+
- Better subclassing support: Determine classes dynamically,
6+
so that methods like str() are aware when our types are subclassed.
7+
38
## Version 3.2.1
49
- Fixed:
510
- Build system includes sortedcontainers dependency in the wheel again

intervaltree/interval.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class Interval(namedtuple('IntervalBase', ['begin', 'end', 'data'])):
3232

3333
def __new__(cls, begin, end, data=None):
3434
return super(Interval, cls).__new__(cls, begin, end, data)
35-
35+
3636
def overlaps(self, begin, end=None):
3737
"""
3838
Whether the interval overlaps the given point, range or Interval.
@@ -44,7 +44,7 @@ def overlaps(self, begin, end=None):
4444
if end is not None:
4545
# An overlap means that some C exists that is inside both ranges:
4646
# begin <= C < end
47-
# and
47+
# and
4848
# self.begin <= C < self.end
4949
# See https://stackoverflow.com/questions/3269434/whats-the-most-efficient-way-to-test-two-integer-ranges-for-overlap/3269471#3269471
5050
return begin < self.end and end > self.begin
@@ -84,7 +84,7 @@ def contains_point(self, p):
8484
:rtype: bool
8585
"""
8686
return self.begin <= p < self.end
87-
87+
8888
def range_matches(self, other):
8989
"""
9090
Whether the begins equal and the ends equal. Compare __eq__().
@@ -93,10 +93,10 @@ def range_matches(self, other):
9393
:rtype: bool
9494
"""
9595
return (
96-
self.begin == other.begin and
96+
self.begin == other.begin and
9797
self.end == other.end
9898
)
99-
99+
100100
def contains_interval(self, other):
101101
"""
102102
Whether other is contained in this Interval.
@@ -108,10 +108,10 @@ def contains_interval(self, other):
108108
self.begin <= other.begin and
109109
self.end >= other.end
110110
)
111-
111+
112112
def distance_to(self, other):
113113
"""
114-
Returns the size of the gap between intervals, or 0
114+
Returns the size of the gap between intervals, or 0
115115
if they touch or overlap.
116116
:param other: Interval or point
117117
:return: distance
@@ -291,7 +291,7 @@ def _get_fields(self):
291291
return self.begin, self.end, self.data
292292
else:
293293
return self.begin, self.end
294-
294+
295295
def __repr__(self):
296296
"""
297297
Executable string representation of this Interval.
@@ -305,9 +305,18 @@ def __repr__(self):
305305
s_begin = repr(self.begin)
306306
s_end = repr(self.end)
307307
if self.data is None:
308-
return "Interval({0}, {1})".format(s_begin, s_end)
308+
return "{0}({1}, {2})".format(
309+
self.__class__.__name__,
310+
s_begin,
311+
s_end,
312+
)
309313
else:
310-
return "Interval({0}, {1}, {2})".format(s_begin, s_end, repr(self.data))
314+
return "{0}({1}, {2}, {3})".format(
315+
self.__class__.__name__,
316+
s_begin,
317+
s_end,
318+
repr(self.data),
319+
)
311320

312321
__str__ = __repr__
313322

@@ -317,12 +326,12 @@ def copy(self):
317326
:return: copy of self
318327
:rtype: Interval
319328
"""
320-
return Interval(self.begin, self.end, self.data)
321-
329+
return self.__class__(self.begin, self.end, self.data)
330+
322331
def __reduce__(self):
323332
"""
324333
For pickle-ing.
325334
:return: pickle data
326335
:rtype: tuple
327336
"""
328-
return Interval, self._get_fields()
337+
return self.__class__, self._get_fields()

intervaltree/intervaltree.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def from_tuples(cls, tups):
247247
where the tuple lists begin, end, and optionally data.
248248
"""
249249
ivs = [Interval(*t) for t in tups]
250-
return IntervalTree(ivs)
250+
return cls(ivs)
251251

252252
def __init__(self, intervals=None):
253253
"""
@@ -277,7 +277,7 @@ def copy(self):
277277
Completes in O(n*log n) time.
278278
:rtype: IntervalTree
279279
"""
280-
return IntervalTree(iv.copy() for iv in self)
280+
return self.__class__(iv.copy() for iv in self)
281281

282282
def _add_boundaries(self, interval):
283283
"""
@@ -407,7 +407,7 @@ def difference(self, other):
407407
for iv in self:
408408
if iv not in other:
409409
ivs.add(iv)
410-
return IntervalTree(ivs)
410+
return self.__class__(ivs)
411411

412412
def difference_update(self, other):
413413
"""
@@ -421,7 +421,7 @@ def union(self, other):
421421
Returns a new tree, comprising all intervals from self
422422
and other.
423423
"""
424-
return IntervalTree(set(self).union(other))
424+
return self.__class__(set(self).union(other))
425425

426426
def intersection(self, other):
427427
"""
@@ -433,7 +433,7 @@ def intersection(self, other):
433433
for iv in shorter:
434434
if iv in longer:
435435
ivs.add(iv)
436-
return IntervalTree(ivs)
436+
return self.__class__(ivs)
437437

438438
def intersection_update(self, other):
439439
"""
@@ -452,7 +452,7 @@ def symmetric_difference(self, other):
452452
if not isinstance(other, set): other = set(other)
453453
me = set(self)
454454
ivs = me.difference(other).union(other.difference(me))
455-
return IntervalTree(ivs)
455+
return self.__class__(ivs)
456456

457457
def symmetric_difference_update(self, other):
458458
"""
@@ -1193,7 +1193,7 @@ def __eq__(self, other):
11931193
:rtype: bool
11941194
"""
11951195
return (
1196-
isinstance(other, IntervalTree) and
1196+
isinstance(other, self.__class__) and
11971197
self.all_intervals == other.all_intervals
11981198
)
11991199

@@ -1203,9 +1203,9 @@ def __repr__(self):
12031203
"""
12041204
ivs = sorted(self)
12051205
if not ivs:
1206-
return "IntervalTree()"
1206+
return "{0}()".format(self.__class__.__name__)
12071207
else:
1208-
return "IntervalTree({0})".format(ivs)
1208+
return "{0}({1})".format(self.__class__.__name__, ivs)
12091209

12101210
__str__ = __repr__
12111211

@@ -1214,5 +1214,5 @@ def __reduce__(self):
12141214
For pickle-ing.
12151215
:rtype: tuple
12161216
"""
1217-
return IntervalTree, (sorted(self.all_intervals),)
1217+
return self.__class__, (sorted(self.all_intervals),)
12181218

intervaltree/node.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def from_interval(cls, interval):
6262
:rtype : Node
6363
"""
6464
center = interval.begin
65-
return Node(center, [interval])
65+
return cls(center, [interval])
6666

6767
@classmethod
6868
def from_intervals(cls, intervals):
@@ -71,7 +71,7 @@ def from_intervals(cls, intervals):
7171
"""
7272
if not intervals:
7373
return None
74-
return Node.from_sorted_intervals(sorted(intervals))
74+
return cls.from_sorted_intervals(sorted(intervals))
7575

7676
@classmethod
7777
def from_sorted_intervals(cls, intervals):
@@ -80,7 +80,7 @@ def from_sorted_intervals(cls, intervals):
8080
"""
8181
if not intervals:
8282
return None
83-
node = Node()
83+
node = cls()
8484
node = node.init_from_sorted(intervals)
8585
return node
8686

@@ -99,8 +99,8 @@ def init_from_sorted(self, intervals):
9999
s_right.append(k)
100100
else:
101101
self.s_center.add(k)
102-
self.left_node = Node.from_sorted_intervals(s_left)
103-
self.right_node = Node.from_sorted_intervals(s_right)
102+
self.left_node = self.__class__.from_sorted_intervals(s_left)
103+
self.right_node = self.__class__.from_sorted_intervals(s_right)
104104
return self.rotate()
105105

106106
def center_hit(self, interval):
@@ -212,7 +212,7 @@ def add(self, interval):
212212
else:
213213
direction = self.hit_branch(interval)
214214
if not self[direction]:
215-
self[direction] = Node.from_interval(interval)
215+
self[direction] = self.__class__.from_interval(interval)
216216
self.refresh_balance()
217217
return self
218218
else:
@@ -392,7 +392,7 @@ def get_new_s_center():
392392
if iv.contains_point(new_x_center): yield iv
393393

394394
# Create a new node with the largest x_center possible.
395-
child = Node(new_x_center, get_new_s_center())
395+
child = self.__class__(new_x_center, get_new_s_center())
396396
self.s_center -= child.s_center
397397

398398
#print('Pop hit! Returning child = {}'.format(
@@ -527,7 +527,8 @@ def __str__(self):
527527
user, I'm not bothering to make this copy-paste-executable as a
528528
constructor.
529529
"""
530-
return "Node<{0}, depth={1}, balance={2}>".format(
530+
return "{0}<{1}, depth={2}, balance={3}>".format(
531+
self.__class__.__name__,
531532
self.x_center,
532533
self.depth,
533534
self.balance

test/interval_methods/copy_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from intervaltree import Interval
2+
import pickle
3+
4+
def test_copy():
5+
iv0 = Interval(1, 2, 3)
6+
iv1 = iv0.copy()
7+
assert iv1.begin == iv0.begin
8+
assert iv1.end == iv0.end
9+
assert iv1.data == iv0.data
10+
assert iv1 == iv0
11+
12+
iv2 = pickle.loads(pickle.dumps(iv0))
13+
assert iv2.begin == iv0.begin
14+
assert iv2.end == iv0.end
15+
assert iv2.data == iv0.data
16+
assert iv2 == iv0
17+
18+
19+
def test_copy_type():
20+
class MyInterval(Interval):
21+
pass
22+
iv = MyInterval(1, 2)
23+
c = iv.copy()
24+
assert isinstance(c, MyInterval)
25+
26+
27+
if __name__ == "__main__":
28+
import pytest
29+
pytest.main([__file__, '-v'])

test/interval_methods/str_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from intervaltree import Interval
2+
3+
def test_str():
4+
iv = Interval(0, 1)
5+
s = str(iv)
6+
assert s == 'Interval(0, 1)'
7+
assert repr(iv) == s
8+
9+
iv = Interval(0, 1, '[0,1)')
10+
s = str(iv)
11+
assert s == "Interval(0, 1, '[0,1)')"
12+
assert repr(iv) == s
13+
14+
iv = Interval((1,2), (3,4))
15+
s = str(iv)
16+
assert s == 'Interval((1, 2), (3, 4))'
17+
assert repr(iv) == s
18+
19+
iv = Interval((1,2), (3,4), (5, 6))
20+
s = str(iv)
21+
assert s == 'Interval((1, 2), (3, 4), (5, 6))'
22+
assert repr(iv) == s
23+
24+
25+
def test_str_type():
26+
class MyInterval(Interval):
27+
pass
28+
29+
iv = MyInterval(0, 1)
30+
s = str(iv)
31+
assert s == 'MyInterval(0, 1)'
32+
assert repr(iv) == s
33+
34+
iv = MyInterval(0, 1, '[0,1)')
35+
s = str(iv)
36+
assert s == "MyInterval(0, 1, '[0,1)')"
37+
assert repr(iv) == s
38+
39+
iv = MyInterval((1,2), (3,4))
40+
s = str(iv)
41+
assert s == 'MyInterval((1, 2), (3, 4))'
42+
assert repr(iv) == s
43+
44+
iv = MyInterval((1,2), (3,4), (5, 6))
45+
s = str(iv)
46+
assert s == 'MyInterval((1, 2), (3, 4), (5, 6))'
47+
assert repr(iv) == s
48+
49+
50+
if __name__ == "__main__":
51+
import pytest
52+
pytest.main([__file__, '-v'])

test/interval_methods/unary_test.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,6 @@ def test_isnull():
3232
assert iv.is_null()
3333

3434

35-
def test_copy():
36-
iv0 = Interval(1, 2, 3)
37-
iv1 = iv0.copy()
38-
assert iv1.begin == iv0.begin
39-
assert iv1.end == iv0.end
40-
assert iv1.data == iv0.data
41-
assert iv1 == iv0
42-
43-
iv2 = pickle.loads(pickle.dumps(iv0))
44-
assert iv2.begin == iv0.begin
45-
assert iv2.end == iv0.end
46-
assert iv2.data == iv0.data
47-
assert iv2 == iv0
48-
49-
5035
def test_len():
5136
iv = Interval(0, 0)
5237
assert len(iv) == 3
@@ -72,28 +57,6 @@ def test_length():
7257
assert iv.length() == 2.9
7358

7459

75-
def test_str():
76-
iv = Interval(0, 1)
77-
s = str(iv)
78-
assert s == 'Interval(0, 1)'
79-
assert repr(iv) == s
80-
81-
iv = Interval(0, 1, '[0,1)')
82-
s = str(iv)
83-
assert s == "Interval(0, 1, '[0,1)')"
84-
assert repr(iv) == s
85-
86-
iv = Interval((1,2), (3,4))
87-
s = str(iv)
88-
assert s == 'Interval((1, 2), (3, 4))'
89-
assert repr(iv) == s
90-
91-
iv = Interval((1,2), (3,4), (5, 6))
92-
s = str(iv)
93-
assert s == 'Interval((1, 2), (3, 4), (5, 6))'
94-
assert repr(iv) == s
95-
96-
9760
def test_get_fields():
9861
ivn = Interval(0, 1)
9962
ivo = Interval(0, 1, 'hello')

0 commit comments

Comments
 (0)