Skip to content

Commit c67a25a

Browse files
authored
perf: Use promote_offsets for consistent row number generation for index.get_loc (#1957)
* perf: use promote_offsets for consistent row number generation for index.get_loc
* remove unused import
* work on lint
1 parent 770918e commit c67a25a

File tree

2 files changed

+14
-45
lines changed
  • bigframes/core/indexes
  • third_party/bigframes_vendored/pandas/core/indexes

2 files changed

+14
-45
lines changed

bigframes/core/indexes/base.py

Lines changed: 13 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,12 @@
2727
import pandas
2828

2929
from bigframes import dtypes
30-
from bigframes.core.array_value import ArrayValue
3130
import bigframes.core.block_transforms as block_ops
3231
import bigframes.core.blocks as blocks
3332
import bigframes.core.expression as ex
34-
import bigframes.core.identifiers as ids
35-
import bigframes.core.nodes as nodes
3633
import bigframes.core.ordering as order
3734
import bigframes.core.utils as utils
3835
import bigframes.core.validations as validations
39-
import bigframes.core.window_spec as window_spec
4036
import bigframes.dtypes
4137
import bigframes.formatting_helpers as formatter
4238
import bigframes.operations as ops
@@ -272,37 +268,20 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
272268
# Get the index column from the block
273269
index_column = self._block.index_columns[0]
274270

275-
# Apply row numbering to the original data
276-
row_number_column_id = ids.ColumnId.unique()
277-
window_node = nodes.WindowOpNode(
278-
child=self._block._expr.node,
279-
expression=ex.NullaryAggregation(agg_ops.RowNumberOp()),
280-
window_spec=window_spec.unbound(),
281-
output_name=row_number_column_id,
282-
never_skip_nulls=True,
283-
)
284-
285-
windowed_array = ArrayValue(window_node)
286-
windowed_block = blocks.Block(
287-
windowed_array,
288-
index_columns=self._block.index_columns,
289-
column_labels=self._block.column_labels.insert(
290-
len(self._block.column_labels), None
291-
),
292-
index_labels=self._block._index_labels,
271+
# Use promote_offsets to get row numbers (similar to argmax/argmin implementation)
272+
block_with_offsets, offsets_id = self._block.promote_offsets(
273+
"temp_get_loc_offsets_"
293274
)
294275

295276
# Create expression to find matching positions
296277
match_expr = ops.eq_op.as_expr(ex.deref(index_column), ex.const(key))
297-
windowed_block, match_col_id = windowed_block.project_expr(match_expr)
278+
block_with_offsets, match_col_id = block_with_offsets.project_expr(match_expr)
298279

299280
# Filter to only rows where the key matches
300-
filtered_block = windowed_block.filter_by_id(match_col_id)
281+
filtered_block = block_with_offsets.filter_by_id(match_col_id)
301282

302-
# Check if key exists at all by counting on the filtered block
303-
count_agg = ex.UnaryAggregation(
304-
agg_ops.count_op, ex.deref(row_number_column_id.name)
305-
)
283+
# Check if key exists at all by counting
284+
count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id))
306285
count_result = filtered_block._expr.aggregate([(count_agg, "count")])
307286
count_scalar = self._block.session._executor.execute(
308287
count_result
@@ -313,9 +292,7 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
313292

314293
# If only one match, return integer position
315294
if count_scalar == 1:
316-
min_agg = ex.UnaryAggregation(
317-
agg_ops.min_op, ex.deref(row_number_column_id.name)
318-
)
295+
min_agg = ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id))
319296
position_result = filtered_block._expr.aggregate([(min_agg, "position")])
320297
position_scalar = self._block.session._executor.execute(
321298
position_result
@@ -325,32 +302,24 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
325302
# Handle multiple matches based on index monotonicity
326303
is_monotonic = self.is_monotonic_increasing or self.is_monotonic_decreasing
327304
if is_monotonic:
328-
return self._get_monotonic_slice(filtered_block, row_number_column_id)
305+
return self._get_monotonic_slice(filtered_block, offsets_id)
329306
else:
330307
# Return boolean mask for non-monotonic duplicates
331-
mask_block = windowed_block.select_columns([match_col_id])
332-
# Reset the index to use positional integers instead of original index values
308+
mask_block = block_with_offsets.select_columns([match_col_id])
333309
mask_block = mask_block.reset_index(drop=True)
334-
# Ensure correct dtype and name to match pandas behavior
335310
result_series = bigframes.series.Series(mask_block)
336311
return result_series.astype("boolean")
337312

338-
def _get_monotonic_slice(
339-
self, filtered_block, row_number_column_id: "ids.ColumnId"
340-
) -> slice:
313+
def _get_monotonic_slice(self, filtered_block, offsets_id: str) -> slice:
341314
"""Helper method to get a slice for monotonic duplicates with an optimized query."""
342315
# Combine min and max aggregations into a single query for efficiency
343316
min_max_aggs = [
344317
(
345-
ex.UnaryAggregation(
346-
agg_ops.min_op, ex.deref(row_number_column_id.name)
347-
),
318+
ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)),
348319
"min_pos",
349320
),
350321
(
351-
ex.UnaryAggregation(
352-
agg_ops.max_op, ex.deref(row_number_column_id.name)
353-
),
322+
ex.UnaryAggregation(agg_ops.max_op, ex.deref(offsets_id)),
354323
"max_pos",
355324
),
356325
]

third_party/bigframes_vendored/pandas/core/indexes/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ def get_loc(
767767
1 True
768768
2 False
769769
3 True
770-
Name: nan, dtype: boolean
770+
dtype: boolean
771771
772772
Args:
773773
key: Label to get the location for.

0 commit comments

Comments
 (0)