Skip to content

Commit de668d6

Browse files
committed
chore: add private _query_and_wait_bigframes method
Towards internal issue b/409104302
1 parent 3deff1d commit de668d6

File tree

5 files changed

+402
-15
lines changed

5 files changed

+402
-15
lines changed

google/cloud/bigquery/_job_helpers.py

Lines changed: 132 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,21 @@
3535
predicates where it is safe to generate a new query ID.
3636
"""
3737

38+
from __future__ import annotations
39+
3840
import copy
41+
import dataclasses
3942
import functools
4043
import uuid
4144
import textwrap
42-
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
45+
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
4346
import warnings
4447

4548
import google.api_core.exceptions as core_exceptions
4649
from google.api_core import retry as retries
4750

4851
from google.cloud.bigquery import job
52+
import google.cloud.bigquery.job.query
4953
import google.cloud.bigquery.query
5054
from google.cloud.bigquery import table
5155
import google.cloud.bigquery.retry
@@ -116,14 +120,20 @@ def query_jobs_insert(
116120
retry: Optional[retries.Retry],
117121
timeout: Optional[float],
118122
job_retry: Optional[retries.Retry],
123+
callback: Callable,
119124
) -> job.QueryJob:
120125
"""Initiate a query using jobs.insert.
121126
122127
See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert
128+
129+
Args:
130+
callback (Callable):
131+
A callback function used by bigframes to report query progress.
123132
"""
124133
job_id_given = job_id is not None
125134
job_id_save = job_id
126135
job_config_save = job_config
136+
query_sent_factory = QuerySentEventFactory()
127137

128138
def do_query():
129139
# Make a copy now, so that original doesn't get changed by the process
@@ -136,6 +146,15 @@ def do_query():
136146

137147
try:
138148
query_job._begin(retry=retry, timeout=timeout)
149+
callback(
150+
query_sent_factory(
151+
query=query,
152+
billing_project=query_job.project,
153+
location=query_job.location,
154+
job_id=query_job.job_id,
155+
request_id=None,
156+
)
157+
)
139158
except core_exceptions.Conflict as create_exc:
140159
# The thought is if someone is providing their own job IDs and they get
141160
# their job ID generation wrong, this could end up returning results for
@@ -396,6 +415,7 @@ def query_and_wait(
396415
job_retry: Optional[retries.Retry],
397416
page_size: Optional[int] = None,
398417
max_results: Optional[int] = None,
418+
callback: Callable = lambda _: None,
399419
) -> table.RowIterator:
400420
"""Run the query, wait for it to finish, and return the results.
401421
@@ -415,9 +435,8 @@ def query_and_wait(
415435
location (Optional[str]):
416436
Location where to run the job. Must match the location of the
417437
table used in the query as well as the destination table.
418-
project (Optional[str]):
419-
Project ID of the project of where to run the job. Defaults
420-
to the client's project.
438+
project (str):
439+
Project ID of the project in which to run the job.
421440
api_timeout (Optional[float]):
422441
The number of seconds to wait for the underlying HTTP transport
423442
before using ``retry``.
@@ -441,6 +460,8 @@ def query_and_wait(
441460
request. Non-positive values are ignored.
442461
max_results (Optional[int]):
443462
The maximum total number of rows from this request.
463+
callback (Callable):
464+
A callback function used by bigframes to report query progress.
444465
445466
Returns:
446467
google.cloud.bigquery.table.RowIterator:
@@ -479,12 +500,14 @@ def query_and_wait(
479500
retry=retry,
480501
timeout=api_timeout,
481502
job_retry=job_retry,
503+
callback=callback,
482504
),
483505
api_timeout=api_timeout,
484506
wait_timeout=wait_timeout,
485507
retry=retry,
486508
page_size=page_size,
487509
max_results=max_results,
510+
callback=callback,
488511
)
489512

490513
path = _to_query_path(project)
@@ -496,10 +519,23 @@ def query_and_wait(
496519
if client.default_job_creation_mode:
497520
request_body["jobCreationMode"] = client.default_job_creation_mode
498521

522+
query_sent_factory = QuerySentEventFactory()
523+
499524
def do_query():
500-
request_body["requestId"] = make_job_id()
525+
request_id = make_job_id()
526+
request_body["requestId"] = request_id
501527
span_attributes = {"path": path}
502528

529+
callback(
530+
query_sent_factory(
531+
query=query,
532+
billing_project=project,
533+
location=location,
534+
job_id=None,
535+
request_id=request_id,
536+
)
537+
)
538+
503539
# For easier testing, handle the retries ourselves.
504540
if retry is not None:
505541
response = retry(client._call_api)(
@@ -542,8 +578,21 @@ def do_query():
542578
retry=retry,
543579
page_size=page_size,
544580
max_results=max_results,
581+
callback=callback,
545582
)
546583

584+
callback(
585+
QueryFinishedEvent(
586+
billing_project=project,
587+
location=query_results.location,
588+
query_id=query_results.query_id,
589+
job_id=query_results.job_id,
590+
total_rows=query_results.total_rows,
591+
total_bytes_processed=query_results.total_bytes_processed,
592+
slot_millis=query_results.slot_millis,
593+
destination=None,
594+
)
595+
)
547596
return table.RowIterator(
548597
client=client,
549598
api_request=functools.partial(client._call_api, retry, timeout=api_timeout),
@@ -611,19 +660,43 @@ def _wait_or_cancel(
611660
retry: Optional[retries.Retry],
612661
page_size: Optional[int],
613662
max_results: Optional[int],
663+
callback: Callable,
614664
) -> table.RowIterator:
615665
"""Wait for a job to complete and return the results.
616666
617667
If we can't return the results within the ``wait_timeout``, try to cancel
618668
the job.
619669
"""
620670
try:
621-
return job.result(
671+
callback(
672+
QueryReceivedEvent(
673+
billing_project=job.project,
674+
location=job.location,
675+
job_id=job.job_id,
676+
statement_type=job.statement_type,
677+
state=job.state,
678+
query_plan=job.query_plan,
679+
)
680+
)
681+
query_results = job.result(
622682
page_size=page_size,
623683
max_results=max_results,
624684
retry=retry,
625685
timeout=wait_timeout,
626686
)
687+
callback(
688+
QueryFinishedEvent(
689+
billing_project=job.project,
690+
location=query_results.location,
691+
query_id=query_results.query_id,
692+
job_id=query_results.job_id,
693+
total_rows=query_results.total_rows,
694+
total_bytes_processed=query_results.total_bytes_processed,
695+
slot_millis=query_results.slot_millis,
696+
destination=job.destination,
697+
)
698+
)
699+
return query_results
627700
except Exception:
628701
# Attempt to cancel the job since we can't return the results.
629702
try:
@@ -632,3 +705,56 @@ def _wait_or_cancel(
632705
# Don't eat the original exception if cancel fails.
633706
pass
634707
raise
708+
709+
710+
@dataclasses.dataclass(frozen=True)
711+
class QueryFinishedEvent:
712+
"""Query finished successfully."""
713+
714+
billing_project: Optional[str]
715+
location: Optional[str]
716+
query_id: Optional[str]
717+
job_id: Optional[str]
718+
destination: Optional[table.TableReference]
719+
total_rows: Optional[int]
720+
total_bytes_processed: Optional[int]
721+
slot_millis: Optional[int]
722+
723+
724+
@dataclasses.dataclass(frozen=True)
725+
class QueryReceivedEvent:
726+
"""Query received and acknowledged by the BigQuery API."""
727+
728+
billing_project: Optional[str]
729+
location: Optional[str]
730+
job_id: Optional[str]
731+
statement_type: Optional[str]
732+
state: Optional[str]
733+
query_plan: Optional[list[google.cloud.bigquery.job.query.QueryPlanEntry]]
734+
735+
736+
@dataclasses.dataclass(frozen=True)
737+
class QuerySentEvent:
738+
"""Query sent to BigQuery."""
739+
740+
query: str
741+
billing_project: Optional[str]
742+
location: Optional[str]
743+
job_id: Optional[str]
744+
request_id: Optional[str]
745+
746+
747+
class QueryRetryEvent(QuerySentEvent):
748+
"""Query sent another time because the previous attempt failed."""
749+
750+
751+
class QuerySentEventFactory:
752+
"""Creates a QuerySentEvent first, then QueryRetryEvent after that."""
753+
754+
def __init__(self):
755+
self._event_constructor = QuerySentEvent
756+
757+
def __call__(self, **kwargs):
758+
result = self._event_constructor(**kwargs)
759+
self._event_constructor = QueryRetryEvent
760+
return result

google/cloud/bigquery/client.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3633,8 +3633,8 @@ def query_and_wait(
36333633
rate-limit-exceeded errors. Passing ``None`` disables
36343634
job retry. Not all jobs can be retried.
36353635
page_size (Optional[int]):
3636-
The maximum number of rows in each page of results from this
3637-
request. Non-positive values are ignored.
3636+
The maximum number of rows in each page of results from the
3637+
initial jobs.query request. Non-positive values are ignored.
36383638
max_results (Optional[int]):
36393639
The maximum total number of rows from this request.
36403640
@@ -3656,6 +3656,39 @@ def query_and_wait(
36563656
:class:`~google.cloud.bigquery.job.QueryJobConfig`
36573657
class.
36583658
"""
3659+
return self._query_and_wait_bigframes(
3660+
query,
3661+
job_config=job_config,
3662+
location=location,
3663+
project=project,
3664+
api_timeout=api_timeout,
3665+
wait_timeout=wait_timeout,
3666+
retry=retry,
3667+
job_retry=job_retry,
3668+
page_size=page_size,
3669+
max_results=max_results,
3670+
)
3671+
3672+
def _query_and_wait_bigframes(
3673+
self,
3674+
query,
3675+
*,
3676+
job_config: Optional[QueryJobConfig] = None,
3677+
location: Optional[str] = None,
3678+
project: Optional[str] = None,
3679+
api_timeout: TimeoutType = DEFAULT_TIMEOUT,
3680+
wait_timeout: Union[Optional[float], object] = POLLING_DEFAULT_VALUE,
3681+
retry: retries.Retry = DEFAULT_RETRY,
3682+
job_retry: retries.Retry = DEFAULT_JOB_RETRY,
3683+
page_size: Optional[int] = None,
3684+
max_results: Optional[int] = None,
3685+
callback = lambda _: None,
3686+
) -> RowIterator:
3687+
"""See query_and_wait.
3688+
3689+
This method has an extra callback parameter, which is used by bigframes
3690+
to create better progress bars.
3691+
"""
36593692
if project is None:
36603693
project = self.project
36613694

@@ -3681,6 +3714,7 @@ def query_and_wait(
36813714
job_retry=job_retry,
36823715
page_size=page_size,
36833716
max_results=max_results,
3717+
callback=callback,
36843718
)
36853719

36863720
def insert_rows(

google/cloud/bigquery/job/base.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Base classes and helpers for job classes."""
1616

17+
from __future__ import annotations
18+
1719
from collections import namedtuple
1820
import copy
1921
import http
@@ -440,9 +442,12 @@ def configuration(self) -> _JobConfig:
440442
return configuration
441443

442444
@property
443-
def job_id(self):
445+
def job_id(self) -> Optional[str]:
444446
"""str: ID of the job."""
445-
return _helpers._get_sub_prop(self._properties, ["jobReference", "jobId"])
447+
return typing.cast(
448+
Optional[str],
449+
_helpers._get_sub_prop(self._properties, ["jobReference", "jobId"]),
450+
)
446451

447452
@property
448453
def parent_job_id(self):
@@ -493,18 +498,24 @@ def num_child_jobs(self):
493498
return int(count) if count is not None else 0
494499

495500
@property
496-
def project(self):
501+
def project(self) -> Optional[str]:
497502
"""Project bound to the job.
498503
499504
Returns:
500505
str: the project (derived from the client).
501506
"""
502-
return _helpers._get_sub_prop(self._properties, ["jobReference", "projectId"])
507+
return typing.cast(
508+
Optional[str],
509+
_helpers._get_sub_prop(self._properties, ["jobReference", "projectId"]),
510+
)
503511

504512
@property
505-
def location(self):
513+
def location(self) -> Optional[str]:
506514
"""str: Location where the job runs."""
507-
return _helpers._get_sub_prop(self._properties, ["jobReference", "location"])
515+
return typing.cast(
516+
Optional[str],
517+
_helpers._get_sub_prop(self._properties, ["jobReference", "location"]),
518+
)
508519

509520
@property
510521
def reservation_id(self):

google/cloud/bigquery/query.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,11 +1228,18 @@ def location(self):
12281228
12291229
See:
12301230
https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#body.QueryResponse.FIELDS.job_reference
1231+
or https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#body.QueryResponse.FIELDS.location
12311232
12321233
Returns:
12331234
str: Location of the query job.
12341235
"""
1235-
return self._properties.get("jobReference", {}).get("location")
1236+
location = self._properties.get("jobReference", {}).get("location")
1237+
1238+
# Sometimes there's no job, but we still want to get the location
1239+
# information. Prefer the value from job for backwards compatibility.
1240+
if not location:
1241+
location = self._properties.get("location")
1242+
return location
12361243

12371244
@property
12381245
def query_id(self) -> Optional[str]:

0 commit comments

Comments
 (0)