 predicates where it is safe to generate a new query ID.
 """

+from __future__ import annotations
+
 import copy
+import dataclasses
+import datetime
 import functools
 import uuid
 import textwrap
-from typing import Any, Dict, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
 import warnings

 import google.api_core.exceptions as core_exceptions
 from google.api_core import retry as retries

 from google.cloud.bigquery import job
+import google.cloud.bigquery.job.query
 import google.cloud.bigquery.query
 from google.cloud.bigquery import table
 import google.cloud.bigquery.retry
@@ -116,14 +121,21 @@ def query_jobs_insert(
     retry: Optional[retries.Retry],
     timeout: Optional[float],
     job_retry: Optional[retries.Retry],
+    *,
+    callback: Callable = lambda _: None,
 ) -> job.QueryJob:
     """Initiate a query using jobs.insert.

     See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert
+
+    Args:
+        callback (Callable):
+            A callback function used by bigframes to report query progress.
     """
     job_id_given = job_id is not None
     job_id_save = job_id
     job_config_save = job_config
+    query_sent_factory = QuerySentEventFactory()

     def do_query():
         # Make a copy now, so that original doesn't get changed by the process
@@ -136,6 +148,16 @@ def do_query():

         try:
             query_job._begin(retry=retry, timeout=timeout)
+            if job_config is not None and not job_config.dry_run:
+                callback(
+                    query_sent_factory(
+                        query=query,
+                        billing_project=query_job.project,
+                        location=query_job.location,
+                        job_id=query_job.job_id,
+                        request_id=None,
+                    )
+                )
         except core_exceptions.Conflict as create_exc:
             # The thought is if someone is providing their own job IDs and they get
             # their job ID generation wrong, this could end up returning results for
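
The new ``callback`` keyword gives callers such as bigframes a hook into query progress: it receives the event objects defined at the end of this diff. The sketch below is illustrative only and assumes the diff is applied and that these helpers live in ``google.cloud.bigquery._job_helpers`` (the file path is not shown in this excerpt). Because ``QueryRetryEvent`` subclasses ``QuerySentEvent``, the retry case is checked first; the parameter's default of ``lambda _: None`` keeps existing callers unaffected.

# Illustrative only: a callback that logs each progress event.
from google.cloud.bigquery import _job_helpers  # assumed module for this diff


def log_progress(event) -> None:
    # QueryRetryEvent subclasses QuerySentEvent, so check the retry case first.
    if isinstance(event, _job_helpers.QueryRetryEvent):
        print(f"query retried: request_id={event.request_id}, job_id={event.job_id}")
    elif isinstance(event, _job_helpers.QuerySentEvent):
        print(f"query sent: request_id={event.request_id}, job_id={event.job_id}")
    elif isinstance(event, _job_helpers.QueryReceivedEvent):
        print(f"query received: state={event.state}")
    elif isinstance(event, _job_helpers.QueryFinishedEvent):
        print(f"query finished: {event.total_rows} rows, {event.slot_millis} slot-ms")
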
@@ -396,6 +418,7 @@ def query_and_wait(
     job_retry: Optional[retries.Retry],
     page_size: Optional[int] = None,
     max_results: Optional[int] = None,
+    callback: Callable = lambda _: None,
 ) -> table.RowIterator:
     """Run the query, wait for it to finish, and return the results.

@@ -415,9 +438,8 @@ def query_and_wait(
         location (Optional[str]):
             Location where to run the job. Must match the location of the
             table used in the query as well as the destination table.
-        project (Optional[str]):
-            Project ID of the project of where to run the job. Defaults
-            to the client's project.
+        project (str):
+            Project ID of the project of where to run the job.
         api_timeout (Optional[float]):
             The number of seconds to wait for the underlying HTTP transport
             before using ``retry``.
@@ -441,6 +463,8 @@ def query_and_wait(
             request. Non-positive values are ignored.
         max_results (Optional[int]):
             The maximum total number of rows from this request.
+        callback (Callable):
+            A callback function used by bigframes to report query progress.

     Returns:
         google.cloud.bigquery.table.RowIterator:
@@ -479,12 +503,14 @@ def query_and_wait(
                 retry=retry,
                 timeout=api_timeout,
                 job_retry=job_retry,
+                callback=callback,
             ),
             api_timeout=api_timeout,
             wait_timeout=wait_timeout,
             retry=retry,
             page_size=page_size,
             max_results=max_results,
+            callback=callback,
         )

     path = _to_query_path(project)
@@ -496,10 +522,24 @@ def query_and_wait(
     if client.default_job_creation_mode:
         request_body["jobCreationMode"] = client.default_job_creation_mode

+    query_sent_factory = QuerySentEventFactory()
+
     def do_query():
-        request_body["requestId"] = make_job_id()
+        request_id = make_job_id()
+        request_body["requestId"] = request_id
         span_attributes = {"path": path}

+        if "dryRun" not in request_body:
+            callback(
+                query_sent_factory(
+                    query=query,
+                    billing_project=project,
+                    location=location,
+                    job_id=None,
+                    request_id=request_id,
+                )
+            )
+
         # For easier testing, handle the retries ourselves.
         if retry is not None:
             response = retry(client._call_api)(
@@ -542,8 +582,25 @@ def do_query():
                 retry=retry,
                 page_size=page_size,
                 max_results=max_results,
+                callback=callback,
             )

+        if "dryRun" not in request_body:
+            callback(
+                QueryFinishedEvent(
+                    billing_project=project,
+                    location=query_results.location,
+                    query_id=query_results.query_id,
+                    job_id=query_results.job_id,
+                    total_rows=query_results.total_rows,
+                    total_bytes_processed=query_results.total_bytes_processed,
+                    slot_millis=query_results.slot_millis,
+                    destination=None,
+                    created=query_results.created,
+                    started=query_results.started,
+                    ended=query_results.ended,
+                )
+            )
         return table.RowIterator(
             client=client,
             api_request=functools.partial(client._call_api, retry, timeout=api_timeout),
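
On the jobs.query fast path above, a non-dry-run query produces two events: a ``QuerySentEvent`` carrying the ``requestId`` (``job_id`` is ``None`` because no job may be created) and a ``QueryFinishedEvent`` with ``destination=None``. A sketch of observing that order, assuming the ``google.cloud.bigquery._job_helpers`` module path and the parameters of ``query_and_wait`` that are not shown in this excerpt:

from google.cloud import bigquery
from google.cloud.bigquery import _job_helpers  # assumed module for this diff

client = bigquery.Client()  # assumes default credentials and project
events = []  # the callback simply collects every event it is given

rows = _job_helpers.query_and_wait(
    client=client,
    query="SELECT 1 AS x",
    job_config=None,
    location=None,
    project=client.project,
    api_timeout=None,
    wait_timeout=None,
    retry=None,
    job_retry=None,
    callback=events.append,
)
print([type(event).__name__ for event in events])
# Expected for a small query that completes immediately:
# ['QuerySentEvent', 'QueryFinishedEvent']
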
@@ -614,19 +671,52 @@ def _wait_or_cancel(
     retry: Optional[retries.Retry],
     page_size: Optional[int],
     max_results: Optional[int],
+    *,
+    callback: Callable = lambda _: None,
 ) -> table.RowIterator:
     """Wait for a job to complete and return the results.

     If we can't return the results within the ``wait_timeout``, try to cancel
     the job.
     """
     try:
-        return job.result(
+        if not job.dry_run:
+            callback(
+                QueryReceivedEvent(
+                    billing_project=job.project,
+                    location=job.location,
+                    job_id=job.job_id,
+                    statement_type=job.statement_type,
+                    state=job.state,
+                    query_plan=job.query_plan,
+                    created=job.created,
+                    started=job.started,
+                    ended=job.ended,
+                )
+            )
+        query_results = job.result(
             page_size=page_size,
             max_results=max_results,
             retry=retry,
             timeout=wait_timeout,
         )
+        if not job.dry_run:
+            callback(
+                QueryFinishedEvent(
+                    billing_project=job.project,
+                    location=query_results.location,
+                    query_id=query_results.query_id,
+                    job_id=query_results.job_id,
+                    total_rows=query_results.total_rows,
+                    total_bytes_processed=query_results.total_bytes_processed,
+                    slot_millis=query_results.slot_millis,
+                    destination=job.destination,
+                    created=job.created,
+                    started=job.started,
+                    ended=job.ended,
+                )
+            )
+        return query_results
     except Exception:
         # Attempt to cancel the job since we can't return the results.
         try:
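
Taken together, the jobs.insert path reports three events for a non-dry-run query: a ``QuerySentEvent`` (or ``QueryRetryEvent`` on later attempts) once the job is inserted, a ``QueryReceivedEvent`` when ``_wait_or_cancel`` starts waiting, and a ``QueryFinishedEvent`` after ``job.result()`` returns. A test-style sketch, assuming the ``google.cloud.bigquery._job_helpers`` module path and the leading parameters of ``query_jobs_insert`` that are not shown in this excerpt:

from google.cloud import bigquery
from google.cloud.bigquery import _job_helpers  # assumed module for this diff

client = bigquery.Client()  # assumes default credentials and project
events = []

query_job = _job_helpers.query_jobs_insert(
    client=client,
    query="SELECT 1 AS x",
    # A non-dry-run job config, so the QuerySentEvent is emitted (see the
    # ``job_config is not None and not job_config.dry_run`` check above).
    job_config=bigquery.QueryJobConfig(),
    job_id=None,
    job_id_prefix=None,
    location=None,
    project=client.project,
    retry=None,
    timeout=None,
    job_retry=None,
    callback=events.append,
)
rows = _job_helpers._wait_or_cancel(
    query_job,
    api_timeout=None,
    wait_timeout=None,
    retry=None,
    page_size=None,
    max_results=None,
    callback=events.append,
)
print([type(event).__name__ for event in events])
# Expected: ['QuerySentEvent', 'QueryReceivedEvent', 'QueryFinishedEvent']
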
@@ -635,3 +725,62 @@ def _wait_or_cancel(
             # Don't eat the original exception if cancel fails.
             pass
         raise
+
+
+@dataclasses.dataclass(frozen=True)
+class QueryFinishedEvent:
+    """Query finished successfully."""
+
+    billing_project: Optional[str]
+    location: Optional[str]
+    query_id: Optional[str]
+    job_id: Optional[str]
+    destination: Optional[table.TableReference]
+    total_rows: Optional[int]
+    total_bytes_processed: Optional[int]
+    slot_millis: Optional[int]
+    created: Optional[datetime.datetime]
+    started: Optional[datetime.datetime]
+    ended: Optional[datetime.datetime]
+
+
+@dataclasses.dataclass(frozen=True)
+class QueryReceivedEvent:
+    """Query received and acknowledged by the BigQuery API."""
+
+    billing_project: Optional[str]
+    location: Optional[str]
+    job_id: Optional[str]
+    statement_type: Optional[str]
+    state: Optional[str]
+    query_plan: Optional[list[google.cloud.bigquery.job.query.QueryPlanEntry]]
+    created: Optional[datetime.datetime]
+    started: Optional[datetime.datetime]
+    ended: Optional[datetime.datetime]
+
+
+@dataclasses.dataclass(frozen=True)
+class QuerySentEvent:
+    """Query sent to BigQuery."""
+
+    query: str
+    billing_project: Optional[str]
+    location: Optional[str]
+    job_id: Optional[str]
+    request_id: Optional[str]
+
+
+class QueryRetryEvent(QuerySentEvent):
+    """Query sent another time because the previous attempt failed."""
+
+
+class QuerySentEventFactory:
+    """Creates a QuerySentEvent first, then QueryRetryEvent after that."""
+
+    def __init__(self):
+        self._event_constructor = QuerySentEvent
+
+    def __call__(self, **kwargs):
+        result = self._event_constructor(**kwargs)
+        self._event_constructor = QueryRetryEvent
+        return result
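
A short usage sketch of ``QuerySentEventFactory``, showing the retry semantics the ``do_query`` closures rely on: the first call produces a ``QuerySentEvent`` and every later call produces a ``QueryRetryEvent``. The project and request IDs below are made up, and the import path is an assumption (this excerpt does not name the module).

from google.cloud.bigquery import _job_helpers  # assumed module for this diff

factory = _job_helpers.QuerySentEventFactory()

first = factory(
    query="SELECT 1",
    billing_project="example-project",  # hypothetical project ID
    location="US",
    job_id=None,
    request_id="request-1",  # hypothetical request ID
)
second = factory(
    query="SELECT 1",
    billing_project="example-project",
    location="US",
    job_id=None,
    request_id="request-2",
)

assert type(first) is _job_helpers.QuerySentEvent
assert type(second) is _job_helpers.QueryRetryEvent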