Skip to content

Commit 3263bd7

Browse files
authored
feat: add bigframes.bigquery.approx_top_count (#1010)
* feat: add bigframes.bigquery.approx_top_count * fix docs
1 parent 2fe5e48 commit 3263bd7

File tree

4 files changed

+165
-0
lines changed

4 files changed

+165
-0
lines changed

bigframes/bigquery/__init__.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,46 @@ def json_extract_array(
272272
return series._apply_unary_op(ops.JSONExtractArray(json_path=json_path))
273273

274274

275+
# Approximate aggregate functions defined from
276+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/approximate_aggregate_functions
277+
278+
279+
def approx_top_count(
    series: series.Series,
    number: int,
) -> series.Series:
    """Returns the approximate top elements of `series` as an array of STRUCTs.

    The `number` parameter specifies the maximum number of elements returned;
    fewer are returned when the input has fewer distinct values.

    Each `STRUCT` contains two fields. The first field (named `value`) contains an input
    value. The second field (named `count`) contains an `INT64` specifying the number
    of times the value was returned.

    Returns `NULL` if there are zero input rows.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None
        >>> s = bpd.Series(["apple", "apple", "pear", "pear", "pear", "banana"])
        >>> bbq.approx_top_count(s, number=2)
        [{'value': 'pear', 'count': 3}, {'value': 'apple', 'count': 2}]

    Args:
        series (bigframes.series.Series):
            The Series with any data type that the `GROUP BY` clause supports.
        number (int):
            The number of top elements to return. Must be at least 1.

    Returns:
        bigframes.series.Series: A new Series with the result data.

    Raises:
        ValueError: If `number` is less than 1.
    """
    # Validate eagerly so an invalid `number` fails before any aggregation is built.
    if number < 1:
        raise ValueError("The number of approx_top_count must be at least 1")
    return series._apply_aggregation(agg_ops.ApproxTopCountOp(number=number))
313+
314+
275315
def struct(value: dataframe.DataFrame) -> series.Series:
276316
"""Takes a DataFrame and converts it into a Series of structs with each
277317
struct entry corresponding to a DataFrame row and each struct field

bigframes/core/compile/aggregate_compiler.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from __future__ import annotations
16+
1417
import functools
1518
import typing
1619
from typing import cast, List, Optional
@@ -19,6 +22,7 @@
1922
import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops
2023
import ibis
2124
import ibis.expr.datatypes as ibis_dtypes
25+
import ibis.expr.operations as ibis_ops
2226
import ibis.expr.types as ibis_types
2327
import pandas as pd
2428

@@ -196,6 +200,34 @@ def _(
196200
return cast(ibis_types.NumericValue, value)
197201

198202

203+
@compile_unary_agg.register
def _(
    op: agg_ops.ApproxTopCountOp,
    column: ibis_types.Column,
    window=None,
) -> ibis_types.ArrayColumn:
    """Compile an ``ApproxTopCountOp`` over ``column`` into an ibis aggregate call.

    Args:
        op: The aggregation op; its ``number`` is forwarded to the aggregate.
        column: The ibis column to aggregate.
        window: Must be ``None``; windowed evaluation is rejected.

    Returns:
        The ibis expression produced by calling the ``approx_top_count``
        builtin aggregate, typed as an array of ``{value, count}`` structs.

    Raises:
        NotImplementedError: If ``window`` is not ``None``.
    """
    # APPROX_TOP_COUNT has very few allowed windows.
    if window is not None:
        raise NotImplementedError(
            f"Approx top count with windowing is not supported. {constants.FEEDBACK_LINK}"
        )

    # One result element: the input value (same type as the input column)
    # paired with an INT64 count.
    top_entry_type = ibis_dtypes.Struct.from_tuples(
        [("value", column.type()), ("count", ibis_dtypes.int64)]
    )

    # Signature-only stub handed to ibis' builtin-aggregate wrapper; the body is
    # intentionally empty. The function and parameter names are kept verbatim
    # because they are what the wrapper receives.
    def approx_top_count(expression, number: ibis_dtypes.int64):  # type: ignore
        ...

    # The return type depends on the input column, so it cannot be written as a
    # static annotation and is attached after the stub is defined.
    approx_top_count.__annotations__["return"] = ibis_dtypes.Array(top_entry_type)
    builtin_agg = ibis_ops.udf.agg.builtin(approx_top_count)

    return builtin_agg(expression=column, number=op.number)
229+
230+
199231
@compile_unary_agg.register
200232
@numeric_op
201233
def _(

bigframes/operations/aggregations.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,23 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
184184
return input_types[0]
185185

186186

187+
@dataclasses.dataclass(frozen=True)
class ApproxTopCountOp(UnaryAggregateOp):
    """Unary aggregate op named ``approx_top_count``.

    The output is a list of structs, each with a ``value`` field (Arrow type
    derived from the input type) and an int64 ``count`` field.
    """

    name: typing.ClassVar[str] = "approx_top_count"
    # How many top elements to request.
    number: int

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the list-of-struct dtype produced for the first input type.

        Raises:
            TypeError: If the first input type is not orderable.
        """
        value_type = input_types[0]
        if not dtypes.is_orderable(value_type):
            raise TypeError(f"Type {value_type} is not orderable")

        entry_type = pa.struct(
            [
                pa.field("value", dtypes.bigframes_dtype_to_arrow_dtype(value_type)),
                pa.field("count", pa.int64()),
            ]
        )
        return pd.ArrowDtype(pa.list_(entry_type))
202+
203+
187204
@dataclasses.dataclass(frozen=True)
188205
class MeanOp(UnaryAggregateOp):
189206
name: ClassVar[str] = "mean"
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
import bigframes.bigquery as bbq
18+
import bigframes.pandas as bpd
19+
20+
21+
@pytest.mark.parametrize(
    ("data", "expected"),
    [
        pytest.param(
            [1, 2, 3, 3, 2],
            [{"value": 3, "count": 2}, {"value": 2, "count": 2}],
            id="int64",
        ),
        pytest.param(
            ["apple", "apple", "pear", "pear", "pear", "banana"],
            [{"value": "pear", "count": 3}, {"value": "apple", "count": 2}],
            id="string",
        ),
        pytest.param(
            [True, False, True, False, True],
            [{"value": True, "count": 3}, {"value": False, "count": 2}],
            id="bool",
        ),
        pytest.param([], [], id="null"),
    ],
)
def test_approx_top_count_w_dtypes(data, expected):
    """approx_top_count returns the expected top-2 entries for supported dtypes."""
    s = bpd.Series(data)
    result = bbq.approx_top_count(s, number=2)
    assert result == expected


def test_approx_top_count_w_array_raises_type_error():
    """Array-valued input is rejected with TypeError.

    Uses ``pytest.raises`` rather than a non-strict ``xfail`` so the suite
    fails if the exception ever stops being raised.
    """
    s = bpd.Series([[1, 2], [1], [1, 2]])
    with pytest.raises(TypeError):
        bbq.approx_top_count(s, number=2)
51+
52+
53+
@pytest.mark.parametrize(
    ("number", "expected"),
    [
        pytest.param(1, [{"value": 3, "count": 2}], id="one"),
        pytest.param(
            4,
            [
                {"value": 3, "count": 2},
                {"value": 2, "count": 2},
                {"value": 1, "count": 1},
            ],
            id="full",
        ),
    ],
)
def test_approx_top_count_w_numbers(number, expected):
    """`number` caps the result length; asking for more than exist returns all."""
    s = bpd.Series([1, 2, 3, 3, 2])
    result = bbq.approx_top_count(s, number=number)
    assert result == expected


@pytest.mark.parametrize("number", [0, -1], ids=["zero", "negative"])
def test_approx_top_count_w_invalid_number_raises_value_error(number):
    """A `number` below 1 is rejected with ValueError.

    Uses ``pytest.raises`` rather than a non-strict ``xfail`` so the suite
    fails if the validation is ever removed.
    """
    s = bpd.Series([1, 2, 3, 3, 2])
    with pytest.raises(ValueError, match="must be at least 1"):
        bbq.approx_top_count(s, number=number)

0 commit comments

Comments
 (0)