diff --git a/sortedcontainers/sorteddict.py b/sortedcontainers/sorteddict.py index 320b861..fdaca13 100644 --- a/sortedcontainers/sorteddict.py +++ b/sortedcontainers/sorteddict.py @@ -512,6 +512,33 @@ def peekitem(self, index=-1): key = self._list[index] return key, self[key] + def next_smaller_item(self, value): + """Return strictly smaller item than the given key.""" + i = self.bisect_left(value) - 1 + if i < 0: + return None, None + return self.peekitem(index=i) + + def floor_item(self, value): + """Return smaller or equal item than the given key.""" + i = self.bisect_right(value) - 1 + if i < 0: + return None, None + return self.peekitem(index=i) + + def ceil_item(self, value): + """Return larger or equal item than the given key.""" + i = self.bisect_left(value) + if i >= len(self): + return None, None + return self.peekitem(index=i) + + def next_greater_item(self, value): + """Return strictly larger item than the given key.""" + i = self.bisect_right(value) + if i >= len(self): + return None, None + return self.peekitem(index=i) def setdefault(self, key, default=None): """Return value for item identified by `key` in sorted dict. diff --git a/sortedcontainers/sortedlist.py b/sortedcontainers/sortedlist.py index e3b58eb..a857722 100644 --- a/sortedcontainers/sortedlist.py +++ b/sortedcontainers/sortedlist.py @@ -1229,6 +1229,33 @@ def bisect_right(self, value): bisect = bisect_right _bisect_right = bisect_right + def next_smaller(self, value): + """Return strictly smaller value to the given value.""" + i = self.bisect_left(value) - 1 + if i < 0: + return None + return self[i] + + def floor(self, value): + """Return smaller or equal value than the given value.""" + i = self.bisect_right(value) - 1 + if i < 0: + return None + return self[i] + + def ceil(self, value): + """Return larger or equal value than the given value.""" + i = self.bisect_left(value) + if i >= len(self): + return None + return self[i] + + def next_greater(self, value): + """Return strictly larger value than the given value.""" + i = self.bisect_right(value) + if i >= len(self): + return None + return self[i] def count(self, value): """Return number of occurrences of `value` in the sorted list. diff --git a/sortedcontainers/sortedset.py b/sortedcontainers/sortedset.py index be2b899..7fdac11 100644 --- a/sortedcontainers/sortedset.py +++ b/sortedcontainers/sortedset.py @@ -153,6 +153,10 @@ def __init__(self, iterable=None, key=None): self.bisect_left = _list.bisect_left self.bisect = _list.bisect self.bisect_right = _list.bisect_right + self.next_smaller = _list.next_smaller + self.floor = _list.floor + self.ceil = _list.ceil + self.next_greater = _list.next_greater self.index = _list.index self.irange = _list.irange self.islice = _list.islice diff --git a/tests/test_coverage_sorteddict.py b/tests/test_coverage_sorteddict.py index 9b62a7f..47bcc1c 100644 --- a/tests/test_coverage_sorteddict.py +++ b/tests/test_coverage_sorteddict.py @@ -357,6 +357,48 @@ def test_bisect_key2(): assert all(temp.bisect_key_right(val) == ((val % 10) + 1) * 10 for val in range(10)) assert all(temp.bisect_key_left(val) == (val % 10) * 10 for val in range(10)) + +# Test Values for adjacent search functions. +sd = SortedDict({1: "a", 2: "b", 4: "c", 5: "d"}) +k1 = 0 # The key that is smaller than any existing value. +k2 = 1 # The key that is the smallest in the instance. +k3 = 3 # The key that does not exists in the instance. +k4 = 4 # The key that exists in the instance. +k5 = 5 # The key that is the largest in the instance. +k6 = 6 # The key that is larger than any existing value. + +def test_next_smaller_item(): + assert sd.next_smaller_item(k1) == (None, None) + assert sd.next_smaller_item(k2) == (None, None) + assert sd.next_smaller_item(k3) == (2, "b") + assert sd.next_smaller_item(k4) == (2, "b") + assert sd.next_smaller_item(k5) == (4, "c") + assert sd.next_smaller_item(k6) == (5, "d") + +def test_floor_item(): + assert sd.floor_item(k1) == (None, None) + assert sd.floor_item(k2) == (1, "a") + assert sd.floor_item(k3) == (2, "b") + assert sd.floor_item(k4) == (4, "c") + assert sd.floor_item(k5) == (5, "d") + assert sd.floor_item(k6) == (5, "d") + +def test_ceil_item(): + assert sd.ceil_item(k1) == (1, "a") + assert sd.ceil_item(k2) == (1, "a") + assert sd.ceil_item(k3) == (4, "c") + assert sd.ceil_item(k4) == (4, "c") + assert sd.ceil_item(k5) == (5, "d") + assert sd.ceil_item(k6) == (None, None) + +def test_next_greater_item(): + assert sd.next_greater_item(k1) == (1, "a") + assert sd.next_greater_item(k2) == (2, "b") + assert sd.next_greater_item(k3) == (4, "c") + assert sd.next_greater_item(k4) == (5, "d") + assert sd.next_greater_item(k5) == (None, None) + assert sd.next_greater_item(k6) == (None, None) + def test_keysview(): mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)] temp = SortedDict(mapping[:13]) diff --git a/tests/test_coverage_sortedlist.py b/tests/test_coverage_sortedlist.py index 7e009e1..f8bbf4d 100644 --- a/tests/test_coverage_sortedlist.py +++ b/tests/test_coverage_sortedlist.py @@ -368,6 +368,48 @@ def test_bisect_right(): assert slt.bisect_right(10) == 22 assert slt.bisect_right(200) == 200 + +# Test Values for adjacent search functions. +slt = SortedList([1, 2, 4, 4, 5]) +v1 = 0 # The value that is smaller than any existing value. +v2 = 1 # The value that is the smallest in the instance. +v3 = 3 # The value that does not exists in the instance. +v4 = 4 # The value that exists in the instance. +v5 = 5 # The value that is the largest in the instance. +v6 = 6 # The value that is larger than any existing value. + +def test_next_smaller(): + assert slt.next_smaller(v1) is None + assert slt.next_smaller(v2) is None + assert slt.next_smaller(v3) == 2 + assert slt.next_smaller(v4) == 2 + assert slt.next_smaller(v5) == 4 + assert slt.next_smaller(v6) == 5 + +def test_floor(): + assert slt.floor(v1) is None + assert slt.floor(v2) == 1 + assert slt.floor(v3) == 2 + assert slt.floor(v4) == 4 + assert slt.floor(v5) == 5 + assert slt.floor(v6) == 5 + +def test_ceil(): + assert slt.ceil(v1) == 1 + assert slt.ceil(v2) == 1 + assert slt.ceil(v3) == 4 + assert slt.ceil(v4) == 4 + assert slt.ceil(v5) == 5 + assert slt.ceil(v6) is None + +def test_next_greater(): + assert slt.next_greater(v1) == 1 + assert slt.next_greater(v2) == 2 + assert slt.next_greater(v3) == 4 + assert slt.next_greater(v4) == 5 + assert slt.next_greater(v5) is None + assert slt.next_greater(v6) is None + def test_copy(): alpha = SortedList(range(100)) alpha._reset(7) diff --git a/tests/test_coverage_sortedset.py b/tests/test_coverage_sortedset.py index e67487d..a6d2cf4 100644 --- a/tests/test_coverage_sortedset.py +++ b/tests/test_coverage_sortedset.py @@ -255,6 +255,48 @@ def test_bisect_key(): assert all(temp.bisect_key(val) == (val + 1) for val in range(100)) assert all(temp.bisect_key_right(val) == (val + 1) for val in range(100)) + +# Test Values for adjacent search functions. +ss = SortedSet([1, 2, 4, 5]) +v1 = 0 # The value that is smaller than any existing value. +v2 = 1 # The value that is the smallest in the instance. +v3 = 3 # The value that does not exists in the instance. +v4 = 4 # The value that exists in the instance. +v5 = 5 # The value that is the largest in the instance. +v6 = 6 # The value that is larger than any existing value. + +def test_next_smaller(): + assert ss.next_smaller(v1) is None + assert ss.next_smaller(v2) is None + assert ss.next_smaller(v3) == 2 + assert ss.next_smaller(v4) == 2 + assert ss.next_smaller(v5) == 4 + assert ss.next_smaller(v6) == 5 + +def test_floor(): + assert ss.floor(v1) is None + assert ss.floor(v2) == 1 + assert ss.floor(v3) == 2 + assert ss.floor(v4) == 4 + assert ss.floor(v5) == 5 + assert ss.floor(v6) == 5 + +def test_ceil(): + assert ss.ceil(v1) == 1 + assert ss.ceil(v2) == 1 + assert ss.ceil(v3) == 4 + assert ss.ceil(v4) == 4 + assert ss.ceil(v5) == 5 + assert ss.ceil(v6) is None + +def test_next_greater(): + assert ss.next_greater(v1) == 1 + assert ss.next_greater(v2) == 2 + assert ss.next_greater(v3) == 4 + assert ss.next_greater(v4) == 5 + assert ss.next_greater(v5) is None + assert ss.next_greater(v6) is None + def test_clear(): temp = SortedSet(range(100)) temp._reset(7)