3535predicates where it is safe to generate a new query ID.
3636"""
3737
38+ from __future__ import annotations
39+
3840import copy
41+ import dataclasses
42+ import datetime
3943import functools
4044import uuid
4145import textwrap
42- from typing import Any , Dict , Optional , TYPE_CHECKING , Union
46+ from typing import Any , Callable , Dict , Optional , TYPE_CHECKING , Union
4347import warnings
4448
4549import google .api_core .exceptions as core_exceptions
4650from google .api_core import retry as retries
4751
4852from google .cloud .bigquery import job
53+ import google .cloud .bigquery .job .query
4954import google .cloud .bigquery .query
5055from google .cloud .bigquery import table
5156import google .cloud .bigquery .retry
@@ -116,14 +121,21 @@ def query_jobs_insert(
116121 retry : Optional [retries .Retry ],
117122 timeout : Optional [float ],
118123 job_retry : Optional [retries .Retry ],
124+ * ,
125+ callback : Callable = lambda _ : None ,
119126) -> job .QueryJob :
120127 """Initiate a query using jobs.insert.
121128
122129 See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert
130+
131+ Args:
132+ callback (Callable):
133+ A callback function used by bigframes to report query progress.
123134 """
124135 job_id_given = job_id is not None
125136 job_id_save = job_id
126137 job_config_save = job_config
138+ query_sent_factory = QuerySentEventFactory ()
127139
128140 def do_query ():
129141 # Make a copy now, so that original doesn't get changed by the process
@@ -136,6 +148,16 @@ def do_query():
136148
137149 try :
138150 query_job ._begin (retry = retry , timeout = timeout )
151+ if job_config is not None and not job_config .dry_run :
152+ callback (
153+ query_sent_factory (
154+ query = query ,
155+ billing_project = query_job .project ,
156+ location = query_job .location ,
157+ job_id = query_job .job_id ,
158+ request_id = None ,
159+ )
160+ )
139161 except core_exceptions .Conflict as create_exc :
140162 # The thought is if someone is providing their own job IDs and they get
141163 # their job ID generation wrong, this could end up returning results for
@@ -396,6 +418,7 @@ def query_and_wait(
396418 job_retry : Optional [retries .Retry ],
397419 page_size : Optional [int ] = None ,
398420 max_results : Optional [int ] = None ,
421+ callback : Callable = lambda _ : None ,
399422) -> table .RowIterator :
400423 """Run the query, wait for it to finish, and return the results.
401424
@@ -415,9 +438,8 @@ def query_and_wait(
415438 location (Optional[str]):
416439 Location where to run the job. Must match the location of the
417440 table used in the query as well as the destination table.
418- project (Optional[str]):
419- Project ID of the project of where to run the job. Defaults
420- to the client's project.
441+ project (str):
442+ Project ID of the project in which to run the job.
421443 api_timeout (Optional[float]):
422444 The number of seconds to wait for the underlying HTTP transport
423445 before using ``retry``.
@@ -441,6 +463,8 @@ def query_and_wait(
441463 request. Non-positive values are ignored.
442464 max_results (Optional[int]):
443465 The maximum total number of rows from this request.
466+ callback (Callable):
467+ A callback function used by bigframes to report query progress.
444468
445469 Returns:
446470 google.cloud.bigquery.table.RowIterator:
@@ -479,12 +503,14 @@ def query_and_wait(
479503 retry = retry ,
480504 timeout = api_timeout ,
481505 job_retry = job_retry ,
506+ callback = callback ,
482507 ),
483508 api_timeout = api_timeout ,
484509 wait_timeout = wait_timeout ,
485510 retry = retry ,
486511 page_size = page_size ,
487512 max_results = max_results ,
513+ callback = callback ,
488514 )
489515
490516 path = _to_query_path (project )
@@ -496,10 +522,24 @@ def query_and_wait(
496522 if client .default_job_creation_mode :
497523 request_body ["jobCreationMode" ] = client .default_job_creation_mode
498524
525+ query_sent_factory = QuerySentEventFactory ()
526+
499527 def do_query ():
500- request_body ["requestId" ] = make_job_id ()
528+ request_id = make_job_id ()
529+ request_body ["requestId" ] = request_id
501530 span_attributes = {"path" : path }
502531
532+ if "dryRun" not in request_body :
533+ callback (
534+ query_sent_factory (
535+ query = query ,
536+ billing_project = project ,
537+ location = location ,
538+ job_id = None ,
539+ request_id = request_id ,
540+ )
541+ )
542+
503543 # For easier testing, handle the retries ourselves.
504544 if retry is not None :
505545 response = retry (client ._call_api )(
@@ -542,8 +582,25 @@ def do_query():
542582 retry = retry ,
543583 page_size = page_size ,
544584 max_results = max_results ,
585+ callback = callback ,
545586 )
546587
588+ if "dryRun" not in request_body :
589+ callback (
590+ QueryFinishedEvent (
591+ billing_project = project ,
592+ location = query_results .location ,
593+ query_id = query_results .query_id ,
594+ job_id = query_results .job_id ,
595+ total_rows = query_results .total_rows ,
596+ total_bytes_processed = query_results .total_bytes_processed ,
597+ slot_millis = query_results .slot_millis ,
598+ destination = None ,
599+ created = query_results .created ,
600+ started = query_results .started ,
601+ ended = query_results .ended ,
602+ )
603+ )
547604 return table .RowIterator (
548605 client = client ,
549606 api_request = functools .partial (client ._call_api , retry , timeout = api_timeout ),
@@ -561,6 +618,9 @@ def do_query():
561618 query = query ,
562619 total_bytes_processed = query_results .total_bytes_processed ,
563620 slot_millis = query_results .slot_millis ,
621+ created = query_results .created ,
622+ started = query_results .started ,
623+ ended = query_results .ended ,
564624 )
565625
566626 if job_retry is not None :
@@ -598,6 +658,9 @@ def _supported_by_jobs_query(request_body: Dict[str, Any]) -> bool:
598658 "requestId" ,
599659 "createSession" ,
600660 "writeIncrementalResults" ,
661+ "jobTimeoutMs" ,
662+ "reservation" ,
663+ "maxSlots" ,
601664 }
602665
603666 unsupported_keys = request_keys - keys_allowlist
@@ -611,19 +674,52 @@ def _wait_or_cancel(
611674 retry : Optional [retries .Retry ],
612675 page_size : Optional [int ],
613676 max_results : Optional [int ],
677+ * ,
678+ callback : Callable = lambda _ : None ,
614679) -> table .RowIterator :
615680 """Wait for a job to complete and return the results.
616681
617682 If we can't return the results within the ``wait_timeout``, try to cancel
618683 the job.
619684 """
620685 try :
621- return job .result (
686+ if not job .dry_run :
687+ callback (
688+ QueryReceivedEvent (
689+ billing_project = job .project ,
690+ location = job .location ,
691+ job_id = job .job_id ,
692+ statement_type = job .statement_type ,
693+ state = job .state ,
694+ query_plan = job .query_plan ,
695+ created = job .created ,
696+ started = job .started ,
697+ ended = job .ended ,
698+ )
699+ )
700+ query_results = job .result (
622701 page_size = page_size ,
623702 max_results = max_results ,
624703 retry = retry ,
625704 timeout = wait_timeout ,
626705 )
706+ if not job .dry_run :
707+ callback (
708+ QueryFinishedEvent (
709+ billing_project = job .project ,
710+ location = query_results .location ,
711+ query_id = query_results .query_id ,
712+ job_id = query_results .job_id ,
713+ total_rows = query_results .total_rows ,
714+ total_bytes_processed = query_results .total_bytes_processed ,
715+ slot_millis = query_results .slot_millis ,
716+ destination = job .destination ,
717+ created = job .created ,
718+ started = job .started ,
719+ ended = job .ended ,
720+ )
721+ )
722+ return query_results
627723 except Exception :
628724 # Attempt to cancel the job since we can't return the results.
629725 try :
@@ -632,3 +728,62 @@ def _wait_or_cancel(
632728 # Don't eat the original exception if cancel fails.
633729 pass
634730 raise
731+
732+
@dataclasses.dataclass(frozen=True)
class QueryFinishedEvent:
    """Query finished successfully.

    Passed to the progress ``callback`` once query results are available,
    so callers (e.g. bigframes) can report completion.
    """

    # Project billed for the query, if known.
    billing_project: Optional[str]
    location: Optional[str]
    query_id: Optional[str]
    # Set when the query ran as a job; None for jobless (jobs.query) runs.
    job_id: Optional[str]
    # Destination table; None when not determined by the caller.
    destination: Optional[table.TableReference]
    # Result statistics, as reported by the API (may be absent).
    total_rows: Optional[int]
    total_bytes_processed: Optional[int]
    slot_millis: Optional[int]
    # Lifecycle timestamps of the query/job.
    created: Optional[datetime.datetime]
    started: Optional[datetime.datetime]
    ended: Optional[datetime.datetime]
748+
749+
@dataclasses.dataclass(frozen=True)
class QueryReceivedEvent:
    """Query received and acknowledged by the BigQuery API.

    Passed to the progress ``callback`` after the service has accepted the
    query but before results are ready.
    """

    # Project billed for the query, if known.
    billing_project: Optional[str]
    location: Optional[str]
    job_id: Optional[str]
    # e.g. SELECT/INSERT; as reported by the job, when available.
    statement_type: Optional[str]
    # Current job state string from the API.
    state: Optional[str]
    query_plan: Optional[list[google.cloud.bigquery.job.query.QueryPlanEntry]]
    # Lifecycle timestamps of the job.
    created: Optional[datetime.datetime]
    started: Optional[datetime.datetime]
    ended: Optional[datetime.datetime]
763+
764+
@dataclasses.dataclass(frozen=True)
class QuerySentEvent:
    """Query sent to BigQuery.

    Passed to the progress ``callback`` immediately after the request that
    starts the query has been issued.
    """

    # SQL text of the query.
    query: str
    billing_project: Optional[str]
    location: Optional[str]
    # Set on the jobs.insert path; None for jobless (jobs.query) runs.
    job_id: Optional[str]
    # Set on the jobs.query path; None when a job was created.
    request_id: Optional[str]


class QueryRetryEvent(QuerySentEvent):
    """Query sent another time because the previous attempt failed."""


class QuerySentEventFactory:
    """Creates a QuerySentEvent first, then QueryRetryEvent after that.

    Each call constructs an event from the given keyword arguments; the
    first call produces a ``QuerySentEvent`` and every later call produces
    a ``QueryRetryEvent``.
    """

    def __init__(self):
        # Flips to True once the initial "sent" event has been produced.
        self._sent_once = False

    def __call__(self, **kwargs):
        event_class = QueryRetryEvent if self._sent_once else QuerySentEvent
        self._sent_once = True
        return event_class(**kwargs)
0 commit comments