Commit a7963fe
feat!: add allow_large_results option to read_gbq_query, aligning with bpd.options.compute.allow_large_results option (#1935)
Release-As: 2.18.0
Parent: 8689199
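As a quick illustration of the new option, here is a minimal usage sketch. The project, dataset, and table names are placeholders, and the comments restate the behavior described in the docstring added by this commit.

import bigframes.pandas as bpd

# Session-wide default; per-call arguments override it.
bpd.options.compute.allow_large_results = False

# Small result: skip the large-results path so the rows can come straight
# back in the query response.
small = bpd.read_gbq_query("SELECT 1 AS x", allow_large_results=False)

# Potentially large result: allow results bigger than the maximum response
# size, as described in the new docstring.
large = bpd.read_gbq_query(
    "SELECT * FROM `my-project.my_dataset.big_table`",  # placeholder table
    allow_large_results=True,
)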

File tree: 18 files changed, +529 -178 lines changed


bigframes/bigquery/_operations/search.py

Lines changed: 12 additions & 5 deletions
@@ -99,6 +99,7 @@ def vector_search(
     distance_type: Optional[Literal["euclidean", "cosine", "dot_product"]] = None,
     fraction_lists_to_search: Optional[float] = None,
     use_brute_force: Optional[bool] = None,
+    allow_large_results: Optional[bool] = None,
 ) -> dataframe.DataFrame:
     """
     Conduct vector search which searches embeddings to find semantically similar entities.
@@ -163,12 +164,12 @@ def vector_search(
         ...     query=search_query,
         ...     distance_type="cosine",
         ...     query_column_to_search="another_embedding",
-        ...     top_k=2)
+        ...     top_k=2).sort_values("id")
           query_id embedding another_embedding  id my_embedding  distance
-        1      cat  [3. 5.2]         [3.3 5.2]   2      [2. 4.]  0.005181
-        0      dog   [1. 2.]         [0.7 2.2]   4     [1. 3.2]  0.000013
         1      cat  [3. 5.2]         [3.3 5.2]   1      [1. 2.]  0.005181
+        1      cat  [3. 5.2]         [3.3 5.2]   2      [2. 4.]  0.005181
         0      dog   [1. 2.]         [0.7 2.2]   3    [1.5 7. ]  0.004697
+        0      dog   [1. 2.]         [0.7 2.2]   4     [1. 3.2]  0.000013
         <BLANKLINE>
         [4 rows x 6 columns]

@@ -199,6 +200,10 @@ def vector_search(
         use_brute_force (bool):
             Determines whether to use brute force search by skipping the vector index if one is available.
             Default to False.
+        allow_large_results (bool, optional):
+            Whether to allow large query results. If ``True``, the query
+            results can be larger than the maximum response size.
+            Defaults to ``bpd.options.compute.allow_large_results``.

     Returns:
         bigframes.dataframe.DataFrame: A DataFrame containing vector search result.
@@ -236,9 +241,11 @@ def vector_search(
         options=options,
     )
     if index_col_ids is not None:
-        df = query._session.read_gbq(sql, index_col=index_col_ids)
+        df = query._session.read_gbq_query(
+            sql, index_col=index_col_ids, allow_large_results=allow_large_results
+        )
         df.index.names = index_labels
     else:
-        df = query._session.read_gbq(sql)
+        df = query._session.read_gbq_query(sql, allow_large_results=allow_large_results)

     return df
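The doctest change above only adds a sort_values("id") call so the example output is deterministic. Below is a sketch of how the new parameter itself would be passed; the base_table and column_to_search values are placeholders, since the doctest's table setup is not shown in this hunk.

import bigframes.bigquery as bbq
import bigframes.pandas as bpd

search_query = bpd.DataFrame(
    {
        "query_id": ["dog", "cat"],
        "embedding": [[1.0, 2.0], [3.0, 5.2]],
        "another_embedding": [[0.7, 2.2], [3.3, 5.2]],
    }
)

result = bbq.vector_search(
    base_table="my-project.my_dataset.embeddings",  # placeholder table
    column_to_search="my_embedding",                # placeholder column
    query=search_query,
    distance_type="cosine",
    query_column_to_search="another_embedding",
    top_k=2,
    # New in this commit: when only a handful of rows are expected, skipping
    # the large-results path can reduce latency. Omit the argument to fall
    # back to bpd.options.compute.allow_large_results.
    allow_large_results=False,
).sort_values("id")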

bigframes/dataframe.py

Lines changed: 1 addition & 1 deletion
@@ -4496,7 +4496,7 @@ def to_dict(
         allow_large_results: Optional[bool] = None,
         **kwargs,
     ) -> dict | list[dict]:
-        return self.to_pandas(allow_large_results=allow_large_results).to_dict(orient, into, **kwargs)  # type: ignore
+        return self.to_pandas(allow_large_results=allow_large_results).to_dict(orient=orient, into=into, **kwargs)  # type: ignore

     def to_excel(
         self,
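The one-line dataframe.py change forwards orient and into to pandas by keyword instead of by position. The commit does not state the motivation, but keyword arguments keep the wrapper robust if pandas reorders these parameters or makes them keyword-only. A plain-pandas sketch of the forwarded call:

import pandas as pd

pdf = pd.DataFrame({"a": [1, 2]})

# Equivalent to what DataFrame.to_dict now forwards: orient and into are
# passed by keyword rather than positionally.
print(pdf.to_dict(orient="records", into=dict))  # [{'a': 1}, {'a': 2}]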

bigframes/ml/core.py

Lines changed: 76 additions & 16 deletions
@@ -45,7 +45,11 @@ def ai_forecast(
         result_sql = self._sql_generator.ai_forecast(
             source_sql=input_data.sql, options=options
         )
-        return self._session.read_gbq(result_sql)
+
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(result_sql, allow_large_results=True)


 class BqmlModel(BaseBqml):
@@ -95,7 +99,17 @@ def _apply_ml_tvf(
         )

         result_sql = apply_sql_tvf(input_sql)
-        df = self._session.read_gbq(result_sql, index_col=index_col_ids)
+        df = self._session.read_gbq_query(
+            result_sql,
+            index_col=index_col_ids,
+            # Many ML methods use nested JSON, which isn't yet compatible with
+            # joining local results. Also, there is a chance that the results
+            # are greater than 10 GB.
+            # TODO(b/395912450): Once the limitations with local data are
+            # resolved, consider setting allow_large_results only when expected
+            # data size is large.
+            allow_large_results=True,
+        )
         if df._has_index:
             df.index.names = index_labels
         # Restore column labels
@@ -159,7 +173,10 @@ def explain_predict(
     def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_global_explain(struct_options=options)
         return (
-            self._session.read_gbq(sql)
+            # TODO(b/395912450): Once the limitations with local data are
+            # resolved, consider setting allow_large_results only when expected
+            # data size is large.
+            self._session.read_gbq_query(sql, allow_large_results=True)
             .sort_values(by="attribution", ascending=False)
             .set_index("feature")
         )
@@ -234,26 +251,49 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_forecast(struct_options=options)
         timestamp_col_name = "forecast_timestamp"
         index_cols = [timestamp_col_name]
-        first_col_name = self._session.read_gbq(sql).columns.values[0]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        first_col_name = self._session.read_gbq_query(
+            sql, allow_large_results=True
+        ).columns.values[0]
         if timestamp_col_name != first_col_name:
             index_cols.append(first_col_name)
-        return self._session.read_gbq(sql, index_col=index_cols).reset_index()
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=index_cols, allow_large_results=True
+        ).reset_index()

     def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_explain_forecast(struct_options=options)
         timestamp_col_name = "time_series_timestamp"
         index_cols = [timestamp_col_name]
-        first_col_name = self._session.read_gbq(sql).columns.values[0]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        first_col_name = self._session.read_gbq_query(
+            sql, allow_large_results=True
+        ).columns.values[0]
         if timestamp_col_name != first_col_name:
             index_cols.append(first_col_name)
-        return self._session.read_gbq(sql, index_col=index_cols).reset_index()
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=index_cols, allow_large_results=True
+        ).reset_index()

     def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
         sql = self._sql_generator.ml_evaluate(
             input_data.sql if (input_data is not None) else None
         )

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def llm_evaluate(
         self,
@@ -262,42 +302,62 @@ def llm_evaluate(
     ):
         sql = self._sql_generator.ml_llm_evaluate(input_data.sql, task_type)

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def arima_evaluate(self, show_all_candidate_models: bool = False):
         sql = self._sql_generator.ml_arima_evaluate(show_all_candidate_models)

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def arima_coefficients(self) -> bpd.DataFrame:
         sql = self._sql_generator.ml_arima_coefficients()

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def centroids(self) -> bpd.DataFrame:
         assert self._model.model_type == "KMEANS"

         sql = self._sql_generator.ml_centroids()

-        return self._session.read_gbq(
-            sql, index_col=["centroid_id", "feature"]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=["centroid_id", "feature"], allow_large_results=True
         ).reset_index()

     def principal_components(self) -> bpd.DataFrame:
         assert self._model.model_type == "PCA"

         sql = self._sql_generator.ml_principal_components()

-        return self._session.read_gbq(
-            sql, index_col=["principal_component_id", "feature"]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql,
+            index_col=["principal_component_id", "feature"],
+            allow_large_results=True,
         ).reset_index()

     def principal_component_info(self) -> bpd.DataFrame:
         assert self._model.model_type == "PCA"

         sql = self._sql_generator.ml_principal_component_info()

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel:
         job_config = self._session._prepare_copy_job_config()
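Every BQML read above gains the same TODO comment and a hard-coded allow_large_results=True, because ML results often contain nested JSON and can exceed 10 GB. A hypothetical helper, not part of this commit, that would express the shared pattern once:

def _read_ml_result(session, sql, index_col=None):
    """Hypothetical helper: run a BQML SQL statement and return a DataFrame."""
    kwargs = {} if index_col is None else {"index_col": index_col}
    # ML query results often contain nested JSON and may exceed 10 GB, so
    # they always take the large-results path for now.
    # TODO(b/395912450): once local-data limitations are resolved, enable
    # allow_large_results only when the expected result size is large.
    return session.read_gbq_query(sql, allow_large_results=True, **kwargs)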

bigframes/operations/ai.py

Lines changed: 4 additions & 0 deletions
@@ -566,6 +566,10 @@ def search(
                 column_to_search=embedding_result_column,
                 query=query_df,
                 top_k=top_k,
+                # TODO(tswast): set allow_large_results based on Series size.
+                # If we expect small results, it could be faster to set
+                # allow_large_results to False.
+                allow_large_results=True,
             )
             .rename(columns={"content": search_column})
             .set_index("index")
bigframes/pandas/io/api.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def read_gbq( # type: ignore[overload-overlap]
187187
use_cache: Optional[bool] = ...,
188188
col_order: Iterable[str] = ...,
189189
dry_run: Literal[False] = ...,
190+
allow_large_results: Optional[bool] = ...,
190191
) -> bigframes.dataframe.DataFrame:
191192
...
192193

@@ -203,6 +204,7 @@ def read_gbq(
203204
use_cache: Optional[bool] = ...,
204205
col_order: Iterable[str] = ...,
205206
dry_run: Literal[True] = ...,
207+
allow_large_results: Optional[bool] = ...,
206208
) -> pandas.Series:
207209
...
208210

@@ -218,6 +220,7 @@ def read_gbq(
218220
use_cache: Optional[bool] = None,
219221
col_order: Iterable[str] = (),
220222
dry_run: bool = False,
223+
allow_large_results: Optional[bool] = None,
221224
) -> bigframes.dataframe.DataFrame | pandas.Series:
222225
_set_default_session_location_if_possible(query_or_table)
223226
return global_session.with_default_session(
@@ -231,6 +234,7 @@ def read_gbq(
231234
use_cache=use_cache,
232235
col_order=col_order,
233236
dry_run=dry_run,
237+
allow_large_results=allow_large_results,
234238
)
235239

236240

@@ -400,6 +404,7 @@ def read_gbq_query( # type: ignore[overload-overlap]
400404
col_order: Iterable[str] = ...,
401405
filters: vendored_pandas_gbq.FiltersType = ...,
402406
dry_run: Literal[False] = ...,
407+
allow_large_results: Optional[bool] = ...,
403408
) -> bigframes.dataframe.DataFrame:
404409
...
405410

@@ -416,6 +421,7 @@ def read_gbq_query(
416421
col_order: Iterable[str] = ...,
417422
filters: vendored_pandas_gbq.FiltersType = ...,
418423
dry_run: Literal[True] = ...,
424+
allow_large_results: Optional[bool] = ...,
419425
) -> pandas.Series:
420426
...
421427

@@ -431,6 +437,7 @@ def read_gbq_query(
431437
col_order: Iterable[str] = (),
432438
filters: vendored_pandas_gbq.FiltersType = (),
433439
dry_run: bool = False,
440+
allow_large_results: Optional[bool] = None,
434441
) -> bigframes.dataframe.DataFrame | pandas.Series:
435442
_set_default_session_location_if_possible(query)
436443
return global_session.with_default_session(
@@ -444,6 +451,7 @@ def read_gbq_query(
444451
col_order=col_order,
445452
filters=filters,
446453
dry_run=dry_run,
454+
allow_large_results=allow_large_results,
447455
)
448456

449457

@@ -617,7 +625,11 @@ def from_glob_path(
617625

618626

619627
def _get_bqclient() -> bigquery.Client:
620-
clients_provider = bigframes.session.clients.ClientsProvider(
628+
# Address circular imports in doctest due to bigframes/session/__init__.py
629+
# containing a lot of logic and samples.
630+
from bigframes.session import clients
631+
632+
clients_provider = clients.ClientsProvider(
621633
project=config.options.bigquery.project,
622634
location=config.options.bigquery.location,
623635
use_regional_endpoints=config.options.bigquery.use_regional_endpoints,
@@ -631,11 +643,15 @@ def _get_bqclient() -> bigquery.Client:
631643

632644

633645
def _dry_run(query, bqclient) -> bigquery.QueryJob:
646+
# Address circular imports in doctest due to bigframes/session/__init__.py
647+
# containing a lot of logic and samples.
648+
from bigframes.session import metrics as bf_metrics
649+
634650
job = bqclient.query(query, bigquery.QueryJobConfig(dry_run=True))
635651

636652
# Fix for b/435183833. Log metrics even if a Session isn't available.
637-
if bigframes.session.metrics.LOGGING_NAME_ENV_VAR in os.environ:
638-
metrics = bigframes.session.metrics.ExecutionMetrics()
653+
if bf_metrics.LOGGING_NAME_ENV_VAR in os.environ:
654+
metrics = bf_metrics.ExecutionMetrics()
639655
metrics.count_job_stats(job)
640656
return job
641657

@@ -645,6 +661,10 @@ def _set_default_session_location_if_possible(query):
645661

646662

647663
def _set_default_session_location_if_possible_deferred_query(create_query):
664+
# Address circular imports in doctest due to bigframes/session/__init__.py
665+
# containing a lot of logic and samples.
666+
from bigframes.session._io import bigquery
667+
648668
# Set the location as per the query if this is the first query the user is
649669
# running and:
650670
# (1) Default session has not started yet, and
@@ -666,7 +686,7 @@ def _set_default_session_location_if_possible_deferred_query(create_query):
666686
query = create_query()
667687
bqclient = _get_bqclient()
668688

669-
if bigframes.session._io.bigquery.is_query(query):
689+
if bigquery.is_query(query):
670690
# Intentionally run outside of the session so that we can detect the
671691
# location before creating the session. Since it's a dry_run, labels
672692
# aren't necessary.
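The repeated one-line additions above thread allow_large_results through each typing overload as well as the implementation, since the return type already depends on dry_run. A self-contained sketch of that overload pattern with generic names (not the actual BigQuery DataFrames signatures):

from typing import Literal, Optional, Union, overload


@overload
def run_query(
    sql: str,
    *,
    dry_run: Literal[False] = ...,
    allow_large_results: Optional[bool] = ...,
) -> list:
    ...


@overload
def run_query(
    sql: str,
    *,
    dry_run: Literal[True],
    allow_large_results: Optional[bool] = ...,
) -> dict:
    ...


def run_query(
    sql: str,
    *,
    dry_run: bool = False,
    allow_large_results: Optional[bool] = None,
) -> Union[list, dict]:
    if dry_run:
        # Dry runs return metadata only, mirroring the pandas.Series overload.
        return {"sql": sql, "allow_large_results": allow_large_results}
    # A real run would execute the query; a list of rows stands in for a
    # DataFrame in this sketch.
    return [{"sql": sql, "allow_large_results": allow_large_results}]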
