Skip to content

Commit bbbcaf3

Browse files
authored
feat: implement Index.get_loc (#1921)
* feat: add index get_loc API * update docstring * code update * final polish of the helper function * fix mypy * reset index of result * change docstring * fix docstring * change a function call
1 parent 92a2377 commit bbbcaf3

File tree

3 files changed

+263
-0
lines changed

3 files changed

+263
-0
lines changed

bigframes/core/indexes/base.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,21 @@
2727
import pandas
2828

2929
from bigframes import dtypes
30+
from bigframes.core.array_value import ArrayValue
3031
import bigframes.core.block_transforms as block_ops
3132
import bigframes.core.blocks as blocks
3233
import bigframes.core.expression as ex
34+
import bigframes.core.identifiers as ids
35+
import bigframes.core.nodes as nodes
3336
import bigframes.core.ordering as order
3437
import bigframes.core.utils as utils
3538
import bigframes.core.validations as validations
39+
import bigframes.core.window_spec as window_spec
3640
import bigframes.dtypes
3741
import bigframes.formatting_helpers as formatter
3842
import bigframes.operations as ops
3943
import bigframes.operations.aggregations as agg_ops
44+
import bigframes.series
4045

4146
if typing.TYPE_CHECKING:
4247
import bigframes.dataframe
@@ -247,6 +252,118 @@ def query_job(self) -> bigquery.QueryJob:
247252
self._query_job = query_job
248253
return self._query_job
249254

255+
def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
    """Get integer location, slice or boolean mask for requested label.

    Every row of the underlying block is given a row number, the rows whose
    index value equals ``key`` are kept, and the matching row numbers decide
    the shape of the result.

    Args:
        key:
            The label to search for in the index.

    Returns:
        An integer when exactly one row matches, a slice when several rows
        match and the index is monotonic, otherwise a boolean Series mask
        with one entry per row.

    Raises:
        NotImplementedError: If the index has more than one level.
        KeyError: If the key is not found in the index.
    """
    if self.nlevels != 1:
        raise NotImplementedError("get_loc only supports single-level indexes")

    # Get the index column from the block
    index_column = self._block.index_columns[0]

    # Apply row numbering to the original data
    row_number_column_id = ids.ColumnId.unique()
    window_node = nodes.WindowOpNode(
        child=self._block._expr.node,
        expression=ex.NullaryAggregation(agg_ops.RowNumberOp()),
        window_spec=window_spec.unbound(),
        output_name=row_number_column_id,
        never_skip_nulls=True,
    )

    windowed_array = ArrayValue(window_node)
    windowed_block = blocks.Block(
        windowed_array,
        index_columns=self._block.index_columns,
        # The row-number column is appended as an extra, unlabeled column.
        column_labels=self._block.column_labels.insert(
            len(self._block.column_labels), None
        ),
        index_labels=self._block._index_labels,
    )

    # Create expression to find matching positions
    match_expr = ops.eq_op.as_expr(ex.deref(index_column), ex.const(key))
    windowed_block, match_col_id = windowed_block.project_expr(match_expr)

    # Filter to only rows where the key matches
    filtered_block = windowed_block.filter_by_id(match_col_id)

    # Count the matches and fetch the first matching position in ONE query,
    # so a unique hit does not pay for a second round trip.
    summary_aggs = [
        (
            ex.UnaryAggregation(
                agg_ops.count_op, ex.deref(row_number_column_id.name)
            ),
            "count",
        ),
        (
            ex.UnaryAggregation(
                agg_ops.min_op, ex.deref(row_number_column_id.name)
            ),
            "position",
        ),
    ]
    summary_result = filtered_block._expr.aggregate(summary_aggs)
    summary_df = self._block.session._executor.execute(summary_result).to_pandas()
    match_count = int(summary_df["count"].iloc[0])

    if match_count == 0:
        # repr() keeps string keys quoted without quoting non-string keys.
        raise KeyError(f"{key!r} is not in index")

    # If only one match, return integer position
    if match_count == 1:
        return int(summary_df["position"].iloc[0])

    # Handle multiple matches based on index monotonicity
    is_monotonic = self.is_monotonic_increasing or self.is_monotonic_decreasing
    if is_monotonic:
        return self._get_monotonic_slice(filtered_block, row_number_column_id)

    # Return boolean mask for non-monotonic duplicates
    mask_block = windowed_block.select_columns([match_col_id])
    # Reset the index to use positional integers instead of original index values
    mask_block = mask_block.reset_index(drop=True)
    # Ensure correct dtype and name to match pandas behavior
    result_series = bigframes.series.Series(mask_block)
    return result_series.astype("boolean")
337+
338+
def _get_monotonic_slice(
    self, filtered_block, row_number_column_id: "ids.ColumnId"
) -> slice:
    """Build the slice spanning every matching row of a monotonic index.

    Args:
        filtered_block:
            Block already restricted to the rows that match the key.
        row_number_column_id:
            Identifier of the column holding each row's position.

    Returns:
        ``slice(first, last + 1)`` where ``first`` and ``last`` are the
        smallest and largest row numbers present in ``filtered_block``.
    """
    # Both bounds are requested together so only one query is executed.
    bound_specs = ((agg_ops.min_op, "min_pos"), (agg_ops.max_op, "max_pos"))
    bounds_aggs = [
        (
            ex.UnaryAggregation(op, ex.deref(row_number_column_id.name)),
            label,
        )
        for op, label in bound_specs
    ]
    bounds_expr = filtered_block._expr.aggregate(bounds_aggs)

    executor = self._block.session._executor
    bounds_df = executor.execute(bounds_expr).to_pandas()
    first = int(bounds_df["min_pos"].iloc[0])
    last = int(bounds_df["max_pos"].iloc[0])

    # slice.stop is exclusive, hence the +1 on the last matching position.
    return slice(first, last + 1)
366+
250367
def __repr__(self) -> str:
251368
# Protect against errors with uninitialized Series. See:
252369
# https://github.com/googleapis/python-bigquery-dataframes/issues/728

tests/system/small/test_index.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,110 @@ def test_index_construct_from_list():
3232
pd.testing.assert_index_equal(bf_result, pd_result)
3333

3434

35+
@pytest.mark.parametrize("key, expected_loc", [("a", 0), ("b", 1), ("c", 2)])
def test_get_loc_should_return_int_for_unique_index(key, expected_loc):
    """Behavior: get_loc on a unique index returns an integer position."""
    labels = ["a", "b", "c"]
    bf_index = bpd.Index(labels)
    # The pandas result is used as the known-correct value; the hard-coded
    # expectation additionally pins the exact position for each key.
    pd_result = pd.Index(labels).get_loc(key)

    result = bf_index.get_loc(key)

    assert result == expected_loc
    assert result == pd_result
    assert isinstance(result, int)
46+
47+
48+
def test_get_loc_should_return_slice_for_monotonic_duplicates():
    """A sorted string index with a repeated label yields the same slice as pandas."""
    labels = ["a", "b", "b", "c"]

    actual = bpd.Index(labels).get_loc("b")
    expected = pd.Index(labels).get_loc("b")

    assert isinstance(actual, slice)
    assert actual == expected  # slice(1, 3, None)
58+
59+
60+
def test_get_loc_should_return_slice_for_monotonic_numeric_duplicates():
    """A sorted numeric index with a repeated label yields the same slice as pandas."""
    values = [1, 2, 2, 3]

    actual = bpd.Index(values).get_loc(2)
    expected = pd.Index(values).get_loc(2)

    assert isinstance(actual, slice)
    assert actual == expected  # slice(1, 3, None)
70+
71+
72+
def test_get_loc_should_return_mask_for_non_monotonic_duplicates():
    """An unsorted string index with a repeated label yields a boolean mask."""
    labels = ["a", "b", "c", "b"]

    expected_mask = pd.Index(labels).get_loc("b")
    actual = bpd.Index(labels).get_loc("b")

    assert not isinstance(actual, (int, slice))

    # Prefer a direct numpy conversion; otherwise go through pandas first.
    actual_mask = (
        actual.to_numpy()
        if hasattr(actual, "to_numpy")
        else actual.to_pandas().to_numpy()
    )
    numpy.testing.assert_array_equal(actual_mask, expected_mask)
87+
88+
89+
def test_get_loc_should_return_mask_for_non_monotonic_numeric_duplicates():
    """An unsorted numeric index with a repeated label yields a boolean mask."""
    values = [1, 2, 3, 2]

    expected_mask = pd.Index(values).get_loc(2)
    actual = bpd.Index(values).get_loc(2)

    assert not isinstance(actual, (int, slice))

    # Prefer a direct numpy conversion; otherwise go through pandas first.
    actual_mask = (
        actual.to_numpy()
        if hasattr(actual, "to_numpy")
        else actual.to_pandas().to_numpy()
    )
    numpy.testing.assert_array_equal(actual_mask, expected_mask)
104+
105+
106+
def test_get_loc_should_raise_error_for_missing_key():
    """Looking up a string label absent from the index raises KeyError."""
    index_under_test = bpd.Index(["a", "b", "c"])

    # Only the lookup sits inside the context manager, so a KeyError from
    # index construction cannot satisfy the expectation.
    with pytest.raises(KeyError):
        index_under_test.get_loc("d")
112+
113+
114+
def test_get_loc_should_raise_error_for_missing_numeric_key():
    """Looking up a numeric label absent from the index raises KeyError."""
    index_under_test = bpd.Index([1, 2, 3])

    # Only the lookup sits inside the context manager, so a KeyError from
    # index construction cannot satisfy the expectation.
    with pytest.raises(KeyError):
        index_under_test.get_loc(4)
120+
121+
122+
def test_get_loc_should_work_for_single_element_index():
    """A one-element index resolves its only label to the same location as pandas."""
    actual = bpd.Index(["a"]).get_loc("a")
    expected = pd.Index(["a"]).get_loc("a")

    assert actual == expected
125+
126+
127+
def test_get_loc_should_return_slice_when_all_elements_are_duplicates():
    """When every element equals the key, the result is a slice over the whole index."""
    labels = ["a", "a", "a"]

    actual = bpd.Index(labels).get_loc("a")
    expected = pd.Index(labels).get_loc("a")

    assert isinstance(actual, slice)
    assert actual == expected  # slice(0, 3, None)
137+
138+
35139
def test_index_construct_from_series():
36140
bf_result = bpd.Index(
37141
bpd.Series([3, 14, 159], dtype=pd.Float64Dtype(), name="series_name"),

third_party/bigframes_vendored/pandas/core/indexes/base.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Hashable
55
import typing
66

7+
import bigframes
78
from bigframes import constants
89

910

@@ -741,6 +742,47 @@ def argmin(self) -> int:
741742
"""
742743
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
743744

745+
def get_loc(
    self, key: typing.Any
) -> typing.Union[int, slice, "bigframes.series.Series"]:
    """
    Get integer location, slice or boolean mask for requested label.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> bpd.options.display.progress_bar = None

        >>> unique_index = bpd.Index(list('abc'))
        >>> unique_index.get_loc('b')
        1

        >>> monotonic_index = bpd.Index(list('abbc'))
        >>> monotonic_index.get_loc('b')
        slice(1, 3, None)

        >>> non_monotonic_index = bpd.Index(list('abcb'))
        >>> non_monotonic_index.get_loc('b')
        0    False
        1     True
        2    False
        3     True
        Name: nan, dtype: boolean

    Args:
        key: Label to get the location for.

    Returns:
        Union[int, slice, bigframes.pandas.Series]:
            Integer position of the label for unique indexes.
            Slice object for monotonic indexes with duplicates.
            Boolean Series mask for non-monotonic indexes with duplicates.

    Raises:
        KeyError: If the key is not found in the index.
    """
    # The return annotation is a quoted forward reference so that defining
    # this method does not require the ``bigframes.series`` submodule to be
    # loaded already.
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
785+
744786
def argmax(self) -> int:
745787
"""
746788
Return int position of the largest value in the Series.

0 commit comments

Comments
 (0)