3535predicates where it is safe to generate a new query ID.
3636"""
3737
38+ from __future__ import annotations
39+
3840import copy
41+ import dataclasses
3942import functools
4043import uuid
4144import textwrap
42- from typing import Any , Dict , Optional , TYPE_CHECKING , Union
45+ from typing import Any , Callable , Dict , Optional , TYPE_CHECKING , Union
4346import warnings
4447
4548import google .api_core .exceptions as core_exceptions
4649from google .api_core import retry as retries
4750
4851from google .cloud .bigquery import job
52+ import google .cloud .bigquery .job .query
4953import google .cloud .bigquery .query
5054from google .cloud .bigquery import table
5155import google .cloud .bigquery .retry
@@ -116,14 +120,20 @@ def query_jobs_insert(
116120 retry : Optional [retries .Retry ],
117121 timeout : Optional [float ],
118122 job_retry : Optional [retries .Retry ],
123+ callback : Callable ,
119124) -> job .QueryJob :
120125 """Initiate a query using jobs.insert.
121126
122127 See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert
128+
129+ Args:
130+ callback (Callable):
131+ A callback function used by bigframes to report query progress.
123132 """
124133 job_id_given = job_id is not None
125134 job_id_save = job_id
126135 job_config_save = job_config
136+ query_sent_factory = QuerySentEventFactory ()
127137
128138 def do_query ():
129139 # Make a copy now, so that original doesn't get changed by the process
@@ -136,6 +146,15 @@ def do_query():
136146
137147 try :
138148 query_job ._begin (retry = retry , timeout = timeout )
149+ callback (
150+ query_sent_factory (
151+ query = query ,
152+ billing_project = query_job .project ,
153+ location = query_job .location ,
154+ job_id = query_job .job_id ,
155+ request_id = None ,
156+ )
157+ )
139158 except core_exceptions .Conflict as create_exc :
140159 # The thought is if someone is providing their own job IDs and they get
141160 # their job ID generation wrong, this could end up returning results for
@@ -396,6 +415,7 @@ def query_and_wait(
396415 job_retry : Optional [retries .Retry ],
397416 page_size : Optional [int ] = None ,
398417 max_results : Optional [int ] = None ,
418+ callback : Callable = lambda _ : None ,
399419) -> table .RowIterator :
400420 """Run the query, wait for it to finish, and return the results.
401421
@@ -415,9 +435,8 @@ def query_and_wait(
415435 location (Optional[str]):
416436 Location where to run the job. Must match the location of the
417437 table used in the query as well as the destination table.
418- project (Optional[str]):
419- Project ID of the project of where to run the job. Defaults
420- to the client's project.
438+ project (str):
439+ Project ID of the project of where to run the job.
421440 api_timeout (Optional[float]):
422441 The number of seconds to wait for the underlying HTTP transport
423442 before using ``retry``.
@@ -441,6 +460,8 @@ def query_and_wait(
441460 request. Non-positive values are ignored.
442461 max_results (Optional[int]):
443462 The maximum total number of rows from this request.
463+ callback (Callable):
464+ A callback function used by bigframes to report query progress.
444465
445466 Returns:
446467 google.cloud.bigquery.table.RowIterator:
@@ -479,12 +500,14 @@ def query_and_wait(
479500 retry = retry ,
480501 timeout = api_timeout ,
481502 job_retry = job_retry ,
503+ callback = callback ,
482504 ),
483505 api_timeout = api_timeout ,
484506 wait_timeout = wait_timeout ,
485507 retry = retry ,
486508 page_size = page_size ,
487509 max_results = max_results ,
510+ callback = callback ,
488511 )
489512
490513 path = _to_query_path (project )
@@ -496,10 +519,23 @@ def query_and_wait(
496519 if client .default_job_creation_mode :
497520 request_body ["jobCreationMode" ] = client .default_job_creation_mode
498521
522+ query_sent_factory = QuerySentEventFactory ()
523+
499524 def do_query ():
500- request_body ["requestId" ] = make_job_id ()
525+ request_id = make_job_id ()
526+ request_body ["requestId" ] = request_id
501527 span_attributes = {"path" : path }
502528
529+ callback (
530+ query_sent_factory (
531+ query = query ,
532+ billing_project = project ,
533+ location = location ,
534+ job_id = None ,
535+ request_id = request_id ,
536+ )
537+ )
538+
503539 # For easier testing, handle the retries ourselves.
504540 if retry is not None :
505541 response = retry (client ._call_api )(
@@ -542,8 +578,21 @@ def do_query():
542578 retry = retry ,
543579 page_size = page_size ,
544580 max_results = max_results ,
581+ callback = callback ,
545582 )
546583
584+ callback (
585+ QueryFinishedEvent (
586+ billing_project = project ,
587+ location = query_results .location ,
588+ query_id = query_results .query_id ,
589+ job_id = query_results .job_id ,
590+ total_rows = query_results .total_rows ,
591+ total_bytes_processed = query_results .total_bytes_processed ,
592+ slot_millis = query_results .slot_millis ,
593+ destination = None ,
594+ )
595+ )
547596 return table .RowIterator (
548597 client = client ,
549598 api_request = functools .partial (client ._call_api , retry , timeout = api_timeout ),
@@ -611,19 +660,43 @@ def _wait_or_cancel(
611660 retry : Optional [retries .Retry ],
612661 page_size : Optional [int ],
613662 max_results : Optional [int ],
663+ callback : Callable ,
614664) -> table .RowIterator :
615665 """Wait for a job to complete and return the results.
616666
617667 If we can't return the results within the ``wait_timeout``, try to cancel
618668 the job.
619669 """
620670 try :
621- return job .result (
671+ callback (
672+ QueryReceivedEvent (
673+ billing_project = job .project ,
674+ location = job .location ,
675+ job_id = job .job_id ,
676+ statement_type = job .statement_type ,
677+ state = job .state ,
678+ query_plan = job .query_plan ,
679+ )
680+ )
681+ query_results = job .result (
622682 page_size = page_size ,
623683 max_results = max_results ,
624684 retry = retry ,
625685 timeout = wait_timeout ,
626686 )
687+ callback (
688+ QueryFinishedEvent (
689+ billing_project = job .project ,
690+ location = query_results .location ,
691+ query_id = query_results .query_id ,
692+ job_id = query_results .job_id ,
693+ total_rows = query_results .total_rows ,
694+ total_bytes_processed = query_results .total_bytes_processed ,
695+ slot_millis = query_results .slot_millis ,
696+ destination = job .destination ,
697+ )
698+ )
699+ return query_results
627700 except Exception :
628701 # Attempt to cancel the job since we can't return the results.
629702 try :
@@ -632,3 +705,56 @@ def _wait_or_cancel(
632705 # Don't eat the original exception if cancel fails.
633706 pass
634707 raise
708+
709+
710+ @dataclasses .dataclass (frozen = True )
711+ class QueryFinishedEvent :
712+ """Query finished successfully."""
713+
714+ billing_project : Optional [str ]
715+ location : Optional [str ]
716+ query_id : Optional [str ]
717+ job_id : Optional [str ]
718+ destination : Optional [table .TableReference ]
719+ total_rows : Optional [int ]
720+ total_bytes_processed : Optional [int ]
721+ slot_millis : Optional [int ]
722+
723+
724+ @dataclasses .dataclass (frozen = True )
725+ class QueryReceivedEvent :
726+ """Query received and acknowledged by the BigQuery API."""
727+
728+ billing_project : Optional [str ]
729+ location : Optional [str ]
730+ job_id : Optional [str ]
731+ statement_type : Optional [str ]
732+ state : Optional [str ]
733+ query_plan : Optional [list [google .cloud .bigquery .job .query .QueryPlanEntry ]]
734+
735+
736+ @dataclasses .dataclass (frozen = True )
737+ class QuerySentEvent :
738+ """Query sent to BigQuery."""
739+
740+ query : str
741+ billing_project : Optional [str ]
742+ location : Optional [str ]
743+ job_id : Optional [str ]
744+ request_id : Optional [str ]
745+
746+
747+ class QueryRetryEvent (QuerySentEvent ):
748+ """Query sent another time because the previous failed."""
749+
750+
751+ class QuerySentEventFactory :
752+ """Creates a QuerySentEvent first, then QueryRetryEvent after that."""
753+
754+ def __init__ (self ):
755+ self ._event_constructor = QuerySentEvent
756+
757+ def __call__ (self , ** kwargs ):
758+ result = self ._event_constructor (** kwargs )
759+ self ._event_constructor = QueryRetryEvent
760+ return result
0 commit comments