Skip to content
16 changes: 5 additions & 11 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,24 +312,18 @@ def __contains__(self, key):

def contains(self, key):
"""
return a boolean if this key is IN the index

We accept / allow keys to be not *just* actual
objects.
Return a boolean mask whether the key is contained in the Intervals
of the index.

Parameters
----------
key : int, float, Interval
key : scalar, Interval

Returns
-------
boolean
boolean array
"""
try:
self.get_loc(key)
return True
except KeyError:
return False
return np.array([key in interval for interval in self], dtype='bool')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just do
return self.get_indexer(key) != -1


@classmethod
def from_breaks(cls, breaks, closed='right', name=None, copy=False):
Expand Down
30 changes: 19 additions & 11 deletions pandas/tests/indexes/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,22 +564,30 @@ def test_contains(self):
assert Interval(3, 5) not in i
assert Interval(-1, 0, closed='left') not in i

def testcontains(self):
def test_contains_method(self):
# can select values that are IN the range of a value
i = IntervalIndex.from_arrays([0, 1], [1, 2])

assert i.contains(0.1)
assert i.contains(0.5)
assert i.contains(1)
assert i.contains(Interval(0, 1))
assert i.contains(Interval(0, 2))
expected = np.array([False, False], dtype='bool')
actual = i.contains(0)
tm.assert_numpy_array_equal(actual, expected)
actual = i.contains(3)
tm.assert_numpy_array_equal(actual, expected)

expected = np.array([True, False], dtype='bool')
actual = i.contains(0.5)
tm.assert_numpy_array_equal(actual, expected)
actual = i.contains(1)
tm.assert_numpy_array_equal(actual, expected)

# these overlaps completely
assert i.contains(Interval(0, 3))
assert i.contains(Interval(1, 3))
# TODO what to do with intervals?
# assert i.contains(Interval(0, 1))
# assert i.contains(Interval(0, 2))
#
# # these overlaps completely
# assert i.contains(Interval(0, 3))
# assert i.contains(Interval(1, 3))

assert not i.contains(20)
assert not i.contains(-20)

def test_dropna(self):

Expand Down