@@ -615,3 +615,80 @@ def test_query_and_wait_retries_job_for_DDL_queries(global_time_lock):
615615 _ , kwargs = calls [3 ]
616616 assert kwargs ["method" ] == "POST"
617617 assert kwargs ["path" ] == query_request_path
618+
619+
620+ @pytest .mark .parametrize (
621+ "result_retry_param" ,
622+ [
623+ pytest .param (
624+ {},
625+ id = "default retry {}" ,
626+ ),
627+ pytest .param (
628+ {
629+ "retry" : google .cloud .bigquery .retry .DEFAULT_RETRY .with_timeout (
630+ timeout = 10.0
631+ )
632+ },
633+ id = "custom retry object with timeout 10.0" ,
634+ ),
635+ ],
636+ )
637+ def test_retry_load_job_result (result_retry_param , PROJECT , DS_ID ):
638+ from google .cloud .bigquery .dataset import DatasetReference
639+ from google .cloud .bigquery .job .load import LoadJob
640+ import google .cloud .bigquery .retry
641+
642+ client = make_client ()
643+ conn = client ._connection = make_connection (
644+ dict (
645+ status = dict (state = "RUNNING" ),
646+ jobReference = {"jobId" : "id_1" },
647+ ),
648+ google .api_core .exceptions .ServiceUnavailable ("retry me" ),
649+ dict (
650+ status = dict (state = "DONE" ),
651+ jobReference = {"jobId" : "id_1" },
652+ statistics = {"load" : {"outputRows" : 1 }},
653+ ),
654+ )
655+
656+ table_ref = DatasetReference (project = PROJECT , dataset_id = DS_ID ).table ("new_table" )
657+ job = LoadJob ("id_1" , source_uris = None , destination = table_ref , client = client )
658+
659+ with mock .patch .object (
660+ client , "_call_api" , wraps = client ._call_api
661+ ) as wrapped_call_api :
662+ result = job .result (** result_retry_param )
663+
664+ assert job .state == "DONE"
665+ assert result .output_rows == 1
666+
667+ # Check that _call_api was called multiple times due to retry
668+ assert wrapped_call_api .call_count > 1
669+
670+ # Verify the retry object used in the calls to _call_api
671+ expected_retry = result_retry_param .get (
672+ "retry" , google .cloud .bigquery .retry .DEFAULT_RETRY
673+ )
674+
675+ for call in wrapped_call_api .mock_calls :
676+ name , args , kwargs = call
677+ # The retry object is the first positional argument to _call_api
678+ called_retry = args [0 ]
679+
680+ # We only care about the calls made during the job.result() polling
681+ if kwargs .get ("method" ) == "GET" and "jobs/id_1" in kwargs .get ("path" , "" ):
682+ assert called_retry ._predicate == expected_retry ._predicate
683+ assert called_retry ._initial == expected_retry ._initial
684+ assert called_retry ._maximum == expected_retry ._maximum
685+ assert called_retry ._multiplier == expected_retry ._multiplier
686+ assert called_retry ._deadline == expected_retry ._deadline
687+ if "retry" in result_retry_param :
688+ # Specifically check the timeout for the custom retry case
689+ assert called_retry ._timeout == 10.0
690+ else :
691+ assert called_retry ._timeout == expected_retry ._timeout
692+
693+ # The number of api_request calls should still be 3
694+ assert conn .api_request .call_count == 3
0 commit comments