Skip to content

Commit 6659355

Browse files
tswast and chalmerlowe
authored
chore: add private _query_and_wait_bigframes method (#2250)
* chore: add private `_query_and_wait_bigframes` method Towards internal issue b/409104302 * fix unit tests * revert type hints * lint * Apply suggestions from code review Co-authored-by: Chalmer Lowe <[email protected]> * populate created, started, ended --------- Co-authored-by: Chalmer Lowe <[email protected]>
1 parent 0a95b24 commit 6659355

File tree

6 files changed

+619
-11
lines changed

6 files changed

+619
-11
lines changed

google/cloud/bigquery/_job_helpers.py

Lines changed: 155 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,22 @@
3535
predicates where it is safe to generate a new query ID.
3636
"""
3737

38+
from __future__ import annotations
39+
3840
import copy
41+
import dataclasses
42+
import datetime
3943
import functools
4044
import uuid
4145
import textwrap
42-
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
46+
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
4347
import warnings
4448

4549
import google.api_core.exceptions as core_exceptions
4650
from google.api_core import retry as retries
4751

4852
from google.cloud.bigquery import job
53+
import google.cloud.bigquery.job.query
4954
import google.cloud.bigquery.query
5055
from google.cloud.bigquery import table
5156
import google.cloud.bigquery.retry
@@ -116,14 +121,21 @@ def query_jobs_insert(
116121
retry: Optional[retries.Retry],
117122
timeout: Optional[float],
118123
job_retry: Optional[retries.Retry],
124+
*,
125+
callback: Callable = lambda _: None,
119126
) -> job.QueryJob:
120127
"""Initiate a query using jobs.insert.
121128
122129
See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert
130+
131+
Args:
132+
callback (Callable):
133+
A callback function used by bigframes to report query progress.
123134
"""
124135
job_id_given = job_id is not None
125136
job_id_save = job_id
126137
job_config_save = job_config
138+
query_sent_factory = QuerySentEventFactory()
127139

128140
def do_query():
129141
# Make a copy now, so that original doesn't get changed by the process
@@ -136,6 +148,16 @@ def do_query():
136148

137149
try:
138150
query_job._begin(retry=retry, timeout=timeout)
151+
if job_config is not None and not job_config.dry_run:
152+
callback(
153+
query_sent_factory(
154+
query=query,
155+
billing_project=query_job.project,
156+
location=query_job.location,
157+
job_id=query_job.job_id,
158+
request_id=None,
159+
)
160+
)
139161
except core_exceptions.Conflict as create_exc:
140162
# The thought is if someone is providing their own job IDs and they get
141163
# their job ID generation wrong, this could end up returning results for
@@ -396,6 +418,7 @@ def query_and_wait(
396418
job_retry: Optional[retries.Retry],
397419
page_size: Optional[int] = None,
398420
max_results: Optional[int] = None,
421+
callback: Callable = lambda _: None,
399422
) -> table.RowIterator:
400423
"""Run the query, wait for it to finish, and return the results.
401424
@@ -415,9 +438,8 @@ def query_and_wait(
415438
location (Optional[str]):
416439
Location where to run the job. Must match the location of the
417440
table used in the query as well as the destination table.
418-
project (Optional[str]):
419-
Project ID of the project of where to run the job. Defaults
420-
to the client's project.
441+
project (str):
442+
Project ID of the project of where to run the job.
421443
api_timeout (Optional[float]):
422444
The number of seconds to wait for the underlying HTTP transport
423445
before using ``retry``.
@@ -441,6 +463,8 @@ def query_and_wait(
441463
request. Non-positive values are ignored.
442464
max_results (Optional[int]):
443465
The maximum total number of rows from this request.
466+
callback (Callable):
467+
A callback function used by bigframes to report query progress.
444468
445469
Returns:
446470
google.cloud.bigquery.table.RowIterator:
@@ -479,12 +503,14 @@ def query_and_wait(
479503
retry=retry,
480504
timeout=api_timeout,
481505
job_retry=job_retry,
506+
callback=callback,
482507
),
483508
api_timeout=api_timeout,
484509
wait_timeout=wait_timeout,
485510
retry=retry,
486511
page_size=page_size,
487512
max_results=max_results,
513+
callback=callback,
488514
)
489515

490516
path = _to_query_path(project)
@@ -496,10 +522,24 @@ def query_and_wait(
496522
if client.default_job_creation_mode:
497523
request_body["jobCreationMode"] = client.default_job_creation_mode
498524

525+
query_sent_factory = QuerySentEventFactory()
526+
499527
def do_query():
500-
request_body["requestId"] = make_job_id()
528+
request_id = make_job_id()
529+
request_body["requestId"] = request_id
501530
span_attributes = {"path": path}
502531

532+
if "dryRun" not in request_body:
533+
callback(
534+
query_sent_factory(
535+
query=query,
536+
billing_project=project,
537+
location=location,
538+
job_id=None,
539+
request_id=request_id,
540+
)
541+
)
542+
503543
# For easier testing, handle the retries ourselves.
504544
if retry is not None:
505545
response = retry(client._call_api)(
@@ -542,8 +582,25 @@ def do_query():
542582
retry=retry,
543583
page_size=page_size,
544584
max_results=max_results,
585+
callback=callback,
545586
)
546587

588+
if "dryRun" not in request_body:
589+
callback(
590+
QueryFinishedEvent(
591+
billing_project=project,
592+
location=query_results.location,
593+
query_id=query_results.query_id,
594+
job_id=query_results.job_id,
595+
total_rows=query_results.total_rows,
596+
total_bytes_processed=query_results.total_bytes_processed,
597+
slot_millis=query_results.slot_millis,
598+
destination=None,
599+
created=query_results.created,
600+
started=query_results.started,
601+
ended=query_results.ended,
602+
)
603+
)
547604
return table.RowIterator(
548605
client=client,
549606
api_request=functools.partial(client._call_api, retry, timeout=api_timeout),
@@ -614,19 +671,52 @@ def _wait_or_cancel(
614671
retry: Optional[retries.Retry],
615672
page_size: Optional[int],
616673
max_results: Optional[int],
674+
*,
675+
callback: Callable = lambda _: None,
617676
) -> table.RowIterator:
618677
"""Wait for a job to complete and return the results.
619678
620679
If we can't return the results within the ``wait_timeout``, try to cancel
621680
the job.
622681
"""
623682
try:
624-
return job.result(
683+
if not job.dry_run:
684+
callback(
685+
QueryReceivedEvent(
686+
billing_project=job.project,
687+
location=job.location,
688+
job_id=job.job_id,
689+
statement_type=job.statement_type,
690+
state=job.state,
691+
query_plan=job.query_plan,
692+
created=job.created,
693+
started=job.started,
694+
ended=job.ended,
695+
)
696+
)
697+
query_results = job.result(
625698
page_size=page_size,
626699
max_results=max_results,
627700
retry=retry,
628701
timeout=wait_timeout,
629702
)
703+
if not job.dry_run:
704+
callback(
705+
QueryFinishedEvent(
706+
billing_project=job.project,
707+
location=query_results.location,
708+
query_id=query_results.query_id,
709+
job_id=query_results.job_id,
710+
total_rows=query_results.total_rows,
711+
total_bytes_processed=query_results.total_bytes_processed,
712+
slot_millis=query_results.slot_millis,
713+
destination=job.destination,
714+
created=job.created,
715+
started=job.started,
716+
ended=job.ended,
717+
)
718+
)
719+
return query_results
630720
except Exception:
631721
# Attempt to cancel the job since we can't return the results.
632722
try:
@@ -635,3 +725,62 @@ def _wait_or_cancel(
635725
# Don't eat the original exception if cancel fails.
636726
pass
637727
raise
728+
729+
730+
@dataclasses.dataclass(frozen=True)
731+
class QueryFinishedEvent:
732+
"""Query finished successfully."""
733+
734+
billing_project: Optional[str]
735+
location: Optional[str]
736+
query_id: Optional[str]
737+
job_id: Optional[str]
738+
destination: Optional[table.TableReference]
739+
total_rows: Optional[int]
740+
total_bytes_processed: Optional[int]
741+
slot_millis: Optional[int]
742+
created: Optional[datetime.datetime]
743+
started: Optional[datetime.datetime]
744+
ended: Optional[datetime.datetime]
745+
746+
747+
@dataclasses.dataclass(frozen=True)
748+
class QueryReceivedEvent:
749+
"""Query received and acknowledged by the BigQuery API."""
750+
751+
billing_project: Optional[str]
752+
location: Optional[str]
753+
job_id: Optional[str]
754+
statement_type: Optional[str]
755+
state: Optional[str]
756+
query_plan: Optional[list[google.cloud.bigquery.job.query.QueryPlanEntry]]
757+
created: Optional[datetime.datetime]
758+
started: Optional[datetime.datetime]
759+
ended: Optional[datetime.datetime]
760+
761+
762+
@dataclasses.dataclass(frozen=True)
763+
class QuerySentEvent:
764+
"""Query sent to BigQuery."""
765+
766+
query: str
767+
billing_project: Optional[str]
768+
location: Optional[str]
769+
job_id: Optional[str]
770+
request_id: Optional[str]
771+
772+
773+
class QueryRetryEvent(QuerySentEvent):
774+
"""Query sent another time because the previous attempt failed."""
775+
776+
777+
class QuerySentEventFactory:
778+
"""Creates a QuerySentEvent first, then QueryRetryEvent after that."""
779+
780+
def __init__(self):
781+
self._event_constructor = QuerySentEvent
782+
783+
def __call__(self, **kwargs):
784+
result = self._event_constructor(**kwargs)
785+
self._event_constructor = QueryRetryEvent
786+
return result

google/cloud/bigquery/client.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Client for interacting with the Google BigQuery API."""
1616

1717
from __future__ import absolute_import
18+
from __future__ import annotations
1819
from __future__ import division
1920

2021
from collections import abc as collections_abc
@@ -31,6 +32,7 @@
3132
import typing
3233
from typing import (
3334
Any,
35+
Callable,
3436
Dict,
3537
IO,
3638
Iterable,
@@ -3633,8 +3635,8 @@ def query_and_wait(
36333635
rate-limit-exceeded errors. Passing ``None`` disables
36343636
job retry. Not all jobs can be retried.
36353637
page_size (Optional[int]):
3636-
The maximum number of rows in each page of results from this
3637-
request. Non-positive values are ignored.
3638+
The maximum number of rows in each page of results from the
3639+
initial jobs.query request. Non-positive values are ignored.
36383640
max_results (Optional[int]):
36393641
The maximum total number of rows from this request.
36403642
@@ -3656,6 +3658,39 @@ def query_and_wait(
36563658
:class:`~google.cloud.bigquery.job.QueryJobConfig`
36573659
class.
36583660
"""
3661+
return self._query_and_wait_bigframes(
3662+
query,
3663+
job_config=job_config,
3664+
location=location,
3665+
project=project,
3666+
api_timeout=api_timeout,
3667+
wait_timeout=wait_timeout,
3668+
retry=retry,
3669+
job_retry=job_retry,
3670+
page_size=page_size,
3671+
max_results=max_results,
3672+
)
3673+
3674+
def _query_and_wait_bigframes(
3675+
self,
3676+
query,
3677+
*,
3678+
job_config: Optional[QueryJobConfig] = None,
3679+
location: Optional[str] = None,
3680+
project: Optional[str] = None,
3681+
api_timeout: TimeoutType = DEFAULT_TIMEOUT,
3682+
wait_timeout: Union[Optional[float], object] = POLLING_DEFAULT_VALUE,
3683+
retry: retries.Retry = DEFAULT_RETRY,
3684+
job_retry: retries.Retry = DEFAULT_JOB_RETRY,
3685+
page_size: Optional[int] = None,
3686+
max_results: Optional[int] = None,
3687+
callback: Callable = lambda _: None,
3688+
) -> RowIterator:
3689+
"""See query_and_wait.
3690+
3691+
This method has an extra callback parameter, which is used by bigframes
3692+
to create better progress bars.
3693+
"""
36593694
if project is None:
36603695
project = self.project
36613696

@@ -3681,6 +3716,7 @@ def query_and_wait(
36813716
job_retry=job_retry,
36823717
page_size=page_size,
36833718
max_results=max_results,
3719+
callback=callback,
36843720
)
36853721

36863722
def insert_rows(

google/cloud/bigquery/job/query.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,6 +1550,8 @@ def result( # type: ignore # (incompatible with supertype)
15501550
return _EmptyRowIterator(
15511551
project=self.project,
15521552
location=self.location,
1553+
schema=self.schema,
1554+
total_bytes_processed=self.total_bytes_processed,
15531555
# Intentionally omit job_id and query_id since this doesn't
15541556
# actually correspond to a finished query job.
15551557
)
@@ -1737,7 +1739,11 @@ def is_job_done():
17371739
project=self.project,
17381740
job_id=self.job_id,
17391741
query_id=self.query_id,
1742+
schema=self.schema,
17401743
num_dml_affected_rows=self._query_results.num_dml_affected_rows,
1744+
query=self.query,
1745+
total_bytes_processed=self.total_bytes_processed,
1746+
slot_millis=self.slot_millis,
17411747
)
17421748

17431749
# We know that there's at least 1 row, so only treat the response from

google/cloud/bigquery/query.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,11 +1228,18 @@ def location(self):
12281228
12291229
See:
12301230
https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#body.QueryResponse.FIELDS.job_reference
1231+
or https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#body.QueryResponse.FIELDS.location
12311232
12321233
Returns:
12331234
str: Location of the query job.
12341235
"""
1235-
return self._properties.get("jobReference", {}).get("location")
1236+
location = self._properties.get("jobReference", {}).get("location")
1237+
1238+
# Sometimes there's no job, but we still want to get the location
1239+
# information. Prefer the value from job for backwards compatibility.
1240+
if not location:
1241+
location = self._properties.get("location")
1242+
return location
12361243

12371244
@property
12381245
def query_id(self) -> Optional[str]:

0 commit comments

Comments
 (0)