diff --git a/google/cloud/bigquery/async_client.py b/google/cloud/bigquery/async_client.py
new file mode 100644
index 000000000..297fba585
--- /dev/null
+++ b/google/cloud/bigquery/async_client.py
@@ -0,0 +1,203 @@
+import sys
+from google.cloud.bigquery.client import *
+from google.cloud.bigquery.client import (
+    _add_server_timeout_header,
+    _extract_job_reference,
+)
+from google.cloud.bigquery.opentelemetry_tracing import async_create_span
+from google.cloud.bigquery import _job_helpers
+from google.cloud.bigquery.table import *
+from google.cloud.bigquery.table import _table_arg_to_table_ref
+from google.api_core.page_iterator import HTTPIterator
+from google.cloud.bigquery.query import _QueryResults
+from google.cloud.bigquery.retry import (
+    DEFAULT_ASYNC_JOB_RETRY,
+    DEFAULT_ASYNC_RETRY,
+    DEFAULT_TIMEOUT,
+)
+from google.api_core import retry_async as retries
+
+if sys.version_info >= (3, 9):
+    import asyncio
+    import aiohttp
+    from google.auth.transport import _aiohttp_requests
+
+# This code is experimental
+
+_MIN_GET_QUERY_RESULTS_TIMEOUT = 120
+
+
+class AsyncClient:
+    """Experimental asyncio wrapper around the synchronous BigQuery Client."""
+
+    def __init__(self, *args, **kwargs):
+        self._client = Client(*args, **kwargs)
+
+    async def get_job(
+        self,
+        job_id: Union[str, job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob],
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        retry: retries.AsyncRetry = DEFAULT_ASYNC_RETRY,
+        timeout: TimeoutType = DEFAULT_TIMEOUT,
+    ) -> Union[job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob, job.UnknownJob]:
+        """Fetch a job resource by ID and return the matching job instance.
+
+        Project/location fall back to values embedded in ``job_id`` (when it
+        is a job object) and then to the client's defaults.
+        """
+        extra_params = {"projection": "full"}
+
+        project, location, job_id = _extract_job_reference(
+            job_id, project=project, location=location
+        )
+
+        if project is None:
+            project = self._client.project
+
+        if location is None:
+            location = self._client.location
+
+        if location is not None:
+            extra_params["location"] = location
+
+        path = "/projects/{}/jobs/{}".format(project, job_id)
+
+        span_attributes = {"path": path, "job_id": job_id, "location": location}
+
+        resource = await self._call_api(
+            retry,
+            span_name="BigQuery.getJob",
+            span_attributes=span_attributes,
+            method="GET",
+            path=path,
+            query_params=extra_params,
+            timeout=timeout,
+        )
+
+        # ``resource`` was already awaited above; awaiting it a second time
+        # (as the previous revision did) raises TypeError at runtime.
+        return self._client.job_from_resource(resource)
+
+    async def _get_query_results(  # make async
+        self,
+        job_id: str,
+        retry: retries.AsyncRetry,
+        project: Optional[str] = None,
+        timeout_ms: Optional[int] = None,
+        location: Optional[str] = None,
+        timeout: TimeoutType = DEFAULT_TIMEOUT,
+    ) -> _QueryResults:
+        """Fetch the query results summary (maxResults=0) for a query job."""
+        extra_params: Dict[str, Any] = {"maxResults": 0}
+
+        if timeout is not None:
+            # Clamp client-side timeout so it never undercuts the server-side
+            # wait; non-numeric sentinels fall back to the minimum.
+            if not isinstance(timeout, (int, float)):
+                timeout = _MIN_GET_QUERY_RESULTS_TIMEOUT
+            else:
+                timeout = max(timeout, _MIN_GET_QUERY_RESULTS_TIMEOUT)
+
+        if project is None:
+            project = self._client.project
+
+        if timeout_ms is not None:
+            extra_params["timeoutMs"] = timeout_ms
+
+        if location is None:
+            location = self._client.location
+
+        if location is not None:
+            extra_params["location"] = location
+
+        path = "/projects/{}/queries/{}".format(project, job_id)
+
+        # This call is typically made in a polling loop that checks whether the
+        # job is complete (from QueryJob.done(), called ultimately from
+        # QueryJob.result()). So we don't need to poll here.
+        span_attributes = {"path": path}
+        resource = await self._call_api(
+            retry,
+            span_name="BigQuery.getQueryResults",
+            span_attributes=span_attributes,
+            method="GET",
+            path=path,
+            query_params=extra_params,
+            timeout=timeout,
+        )
+        return _QueryResults.from_api_repr(resource)
+
+    async def get_table(  # make async
+        self,
+        table: Union[Table, TableReference, TableListItem, str],
+        retry: retries.AsyncRetry = DEFAULT_ASYNC_RETRY,
+        timeout: TimeoutType = DEFAULT_TIMEOUT,
+    ) -> Table:
+        """Fetch the table resource for ``table`` and return a Table."""
+        table_ref = _table_arg_to_table_ref(table, default_project=self._client.project)
+        path = table_ref.path
+        span_attributes = {"path": path}
+        api_response = await self._call_api(
+            retry,
+            span_name="BigQuery.getTable",
+            span_attributes=span_attributes,
+            method="GET",
+            path=path,
+            timeout=timeout,
+        )
+
+        return Table.from_api_repr(api_response)
+
+    async def list_partitions(  # make async
+        self,
+        table: Union[Table, TableReference, TableListItem, str],
+        retry: retries.AsyncRetry = DEFAULT_ASYNC_RETRY,
+        timeout: TimeoutType = DEFAULT_TIMEOUT,
+    ) -> Sequence[str]:
+        """List partition IDs via the $__PARTITIONS_SUMMARY__ meta-table."""
+        table = _table_arg_to_table_ref(table, default_project=self._client.project)
+        meta_table = await self.get_table(
+            TableReference(
+                DatasetReference(table.project, table.dataset_id),
+                "%s$__PARTITIONS_SUMMARY__" % table.table_id,
+            ),
+            retry=retry,
+            timeout=timeout,
+        )
+
+        subset = [col for col in meta_table.schema if col.name == "partition_id"]
+        # NOTE(review): list_rows() is the synchronous client API and will
+        # block the event loop; acceptable only while this module is
+        # experimental — TODO replace with an async row iterator.
+        return [
+            row[0]
+            for row in self._client.list_rows(
+                meta_table, selected_fields=subset, retry=retry, timeout=timeout
+            )
+        ]
+
+    async def _call_api(
+        self,
+        retry: Optional[retries.AsyncRetry] = None,
+        span_name: Optional[str] = None,
+        span_attributes: Optional[Dict] = None,
+        job_ref=None,
+        headers: Optional[Dict[str, str]] = None,
+        **kwargs,
+    ):
+        """Issue an authorized aiohttp request, with optional retry/tracing."""
+        kwargs = _add_server_timeout_header(headers, kwargs)
+
+        # CREATE THIN WRAPPER OVER _AIOHTTP_REQUESTS (wip)
+
+        DEFAULT_API_ENDPOINT = "https://bigquery.googleapis.com"
+
+        kwargs["url"] = DEFAULT_API_ENDPOINT + kwargs.pop("path")
+
+        if kwargs.get("query_params"):
+            kwargs["params"] = kwargs.pop("query_params")
+
+        async def do_request():
+            # One session per request; the async context manager guarantees
+            # the session is closed even when the request raises.
+            async with _aiohttp_requests.AuthorizedSession(
+                self._client._credentials
+            ) as authed_session:
+                return await authed_session.request(**kwargs)
+
+        # AsyncRetry decorates the request *callable*. The previous revision
+        # wrapped the already-obtained response (``retry(response)``) and then
+        # called the response object — retries never re-issued the request.
+        call = retry(do_request) if retry is not None else do_request
+
+        if span_name is not None:
+            async with async_create_span(
+                name=span_name,
+                attributes=span_attributes,
+                client=self._client,
+                job_ref=job_ref,
+            ):
+                return await call()
+
+        return await call()
diff --git a/google/cloud/bigquery/opentelemetry_tracing.py b/google/cloud/bigquery/opentelemetry_tracing.py
index e2a05e4d0..c1594c1a2 100644
--- a/google/cloud/bigquery/opentelemetry_tracing.py
+++ b/google/cloud/bigquery/opentelemetry_tracing.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from contextlib import contextmanager
+from contextlib import contextmanager, asynccontextmanager
 from google.api_core.exceptions import GoogleAPICallError  # type: ignore
 
 logger = logging.getLogger(__name__)
@@ -86,6 +86,37 @@ def create_span(name, attributes=None, client=None, job_ref=None):
         raise
 
 
+@asynccontextmanager
+async def async_create_span(name, attributes=None, client=None, job_ref=None):
+    """Asynchronous context manager for creating and exporting OpenTelemetry spans."""
+    global _warned_telemetry
+    final_attributes = _get_final_span_attributes(attributes, client, job_ref)
+
+    if not HAS_OPENTELEMETRY:
+        if not _warned_telemetry:
+            logger.debug(
+                "This service is instrumented using OpenTelemetry. "
+                "OpenTelemetry or one of its components could not be imported; "
+                "please add compatible versions of opentelemetry-api and "
+                "opentelemetry-instrumentation packages in order to get BigQuery "
+                "Tracing data."
+            )
+            _warned_telemetry = True
+        yield None
+        return
+    tracer = trace.get_tracer(__name__)
+
+    # Tracer.start_as_current_span returns a *synchronous* context manager;
+    # ``async with`` (as previously written) fails for lack of __aenter__.
+    with tracer.start_as_current_span(name=name, attributes=final_attributes) as span:
+        try:
+            yield span
+        except GoogleAPICallError as error:
+            if error.code is not None:
+                span.set_status(Status(http_status_to_status_code(error.code)))
+            raise
+
+
 def _get_final_span_attributes(attributes=None, client=None, job_ref=None):
     """Compiles attributes from: client, job_ref, user-provided attributes.
diff --git a/google/cloud/bigquery/retry.py b/google/cloud/bigquery/retry.py
index 01b127972..f49247433 100644
--- a/google/cloud/bigquery/retry.py
+++ b/google/cloud/bigquery/retry.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from google.api_core import exceptions
-from google.api_core import retry
+from google.api_core import retry, retry_async
 from google.auth import exceptions as auth_exceptions  # type: ignore
 import requests.exceptions
 
@@ -90,3 +90,15 @@ def _job_should_retry(exc):
 """
 The default job retry object.
 """
+
+# ``timeout`` is the current name for the deprecated ``deadline`` argument.
+DEFAULT_ASYNC_RETRY = retry_async.AsyncRetry(
+    predicate=_should_retry, timeout=_DEFAULT_RETRY_DEADLINE
+)
+
+DEFAULT_ASYNC_JOB_RETRY = retry_async.AsyncRetry(
+    predicate=_job_should_retry,
+    timeout=_DEFAULT_JOB_DEADLINE,
+)
+# TODO: decide whether async mode needs additional retry predicates and how
+# an overall timeout / maximum-retry bound should be expressed.
diff --git a/noxfile.py b/noxfile.py index c31d098b8..c48b5c5ad 100644 --- a/noxfile.py +++ b/noxfile.py @@ -79,9 +79,10 @@ def default(session, install_extras=True): "-c", constraints_path, ) + session.install("asyncmock", "pytest-asyncio") - if install_extras and session.python in ["3.11", "3.12"]: - install_target = ".[bqstorage,ipywidgets,pandas,tqdm,opentelemetry]" + if install_extras and session.python in ["3.12"]: + install_target = ".[bqstorage,ipywidgets,pandas,tqdm,opentelemetry,aiohttp]" elif install_extras: install_target = ".[all]" else: @@ -104,6 +105,9 @@ def default(session, install_extras=True): *session.posargs, ) + # Having positional arguments means the user wants to run specific tests. + # Best not to add additional tests to that list. + @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) def unit(session): @@ -188,8 +192,8 @@ def system(session): # Data Catalog needed for the column ACL test with a real Policy Tag. session.install("google-cloud-datacatalog", "-c", constraints_path) - if session.python in ["3.11", "3.12"]: - extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry]" + if session.python in ["3.12"]: + extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry,aiohttp]" # look at geopandas to see if it supports 3.11/3.12 (up to 3.11) else: extras = "[all]" session.install("-e", f".{extras}", "-c", constraints_path) @@ -254,8 +258,8 @@ def snippets(session): session.install("google-cloud-storage", "-c", constraints_path) session.install("grpcio", "-c", constraints_path) - if session.python in ["3.11", "3.12"]: - extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry]" + if session.python in ["3.12"]: + extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry,aiohttp]" else: extras = "[all]" session.install("-e", f".{extras}", "-c", constraints_path) diff --git a/setup.py b/setup.py index 5a35f4136..7d672d239 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,9 @@ "proto-plus >= 1.15.0, <2.0.0dev", 
"protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", # For the legacy proto-based types. ], + "aiohttp": [ + "google-auth[aiohttp]", + ], } all_extras = [] diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index d4c302867..f4adf95c3 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -4,5 +4,6 @@ # # NOTE: Not comprehensive yet, will eventually be maintained semi-automatically by # the renovate bot. +aiohttp==3.6.2 grpcio==1.47.0 pyarrow>=4.0.0 diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py new file mode 100644 index 000000000..ffa5b5f2c --- /dev/null +++ b/tests/unit/test_async_client.py @@ -0,0 +1,750 @@ +import copy +import collections +import datetime +import decimal +import email +import gzip +import http.client +import io +import itertools +import json +import operator +import unittest +import warnings + +from unittest import mock +import requests +import packaging +import pytest +import sys +import inspect +from tests.unit.test_client import _make_list_partitons_meta_info + +if sys.version_info >= (3, 9): + import asyncio + +try: + import importlib.metadata as metadata +except ImportError: + import importlib_metadata as metadata + +try: + import pandas +except (ImportError, AttributeError): # pragma: NO COVER + pandas = None + +try: + import opentelemetry +except ImportError: + opentelemetry = None + +if opentelemetry is not None: + try: + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + except (ImportError, AttributeError) as exc: # pragma: NO COVER + msg = "Error importing from opentelemetry, is the installed version compatible?" 
+ raise ImportError(msg) from exc + +try: + import pyarrow +except (ImportError, AttributeError): # pragma: NO COVER + pyarrow = None + +import google.api_core.exceptions +from google.api_core import client_info +import google.cloud._helpers +from google.cloud import bigquery + +from google.cloud.bigquery.dataset import DatasetReference +from google.cloud.bigquery import exceptions +from google.cloud.bigquery import ParquetOptions +from google.cloud.bigquery.retry import DEFAULT_TIMEOUT +import google.cloud.bigquery.table + +try: + from google.cloud import bigquery_storage +except (ImportError, AttributeError): # pragma: NO COVER + bigquery_storage = None +from test_utils.imports import maybe_fail_import +from tests.unit.helpers import make_connection + +if pandas is not None: + PANDAS_INSTALLED_VERSION = metadata.version("pandas") +else: + PANDAS_INSTALLED_VERSION = "0.0.0" + +from google.cloud.bigquery.retry import ( + DEFAULT_ASYNC_JOB_RETRY, + DEFAULT_ASYNC_RETRY, + DEFAULT_TIMEOUT, +) +from google.api_core import retry_async as retries +from google.cloud.bigquery import async_client +from google.cloud.bigquery.async_client import AsyncClient +from google.cloud.bigquery.job import query as job_query + + +def asyncio_run(async_func): + def wrapper(*args, **kwargs): + return asyncio.run(async_func(*args, **kwargs)) + + wrapper.__signature__ = inspect.signature( + async_func + ) # without this, fixtures are not injected + + return wrapper + + +def _make_credentials(): + from google.auth import _credentials_async as credentials + + return mock.Mock(spec=credentials.Credentials) + + +class TestClient(unittest.TestCase): + PROJECT = "PROJECT" + DS_ID = "DATASET_ID" + TABLE_ID = "TABLE_ID" + MODEL_ID = "MODEL_ID" + TABLE_REF = DatasetReference(PROJECT, DS_ID).table(TABLE_ID) + KMS_KEY_NAME = "projects/1/locations/us/keyRings/1/cryptoKeys/1" + LOCATION = "us-central" + + @staticmethod + def _get_target_class(): + from google.cloud.bigquery.async_client import 
AsyncClient + + return AsyncClient + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def _make_table_resource(self): + return { + "id": "%s:%s:%s" % (self.PROJECT, self.DS_ID, self.TABLE_ID), + "tableReference": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "tableId": self.TABLE_ID, + }, + } + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_defaults(self): + from google.cloud.bigquery._http import Connection + + creds = _make_credentials() + http = object() + client = self._make_one( + project=self.PROJECT, credentials=creds, _http=http + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertIsNone(client.location) + self.assertEqual( + client._connection.API_BASE_URL, Connection.DEFAULT_API_ENDPOINT + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_empty_client_options(self): + from google.api_core.client_options import ClientOptions + + creds = _make_credentials() + http = object() + client_options = ClientOptions() + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual( + client._connection.API_BASE_URL, client._connection.DEFAULT_API_ENDPOINT + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_client_options_dict(self): + creds = _make_credentials() + http = object() + client_options = {"api_endpoint": "https://www.foo-googleapis.com"} + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual( + client._connection.API_BASE_URL, "https://www.foo-googleapis.com" + ) + + @pytest.mark.skipif( + sys.version_info < 
(3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_client_options_object(self): + from google.api_core.client_options import ClientOptions + + creds = _make_credentials() + http = object() + client_options = ClientOptions(api_endpoint="https://www.foo-googleapis.com") + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual( + client._connection.API_BASE_URL, "https://www.foo-googleapis.com" + ) + + @pytest.mark.skipif( + packaging.version.parse(getattr(google.api_core, "__version__", "0.0.0")) + < packaging.version.Version("2.15.0"), + reason="universe_domain not supported with google-api-core < 2.15.0", + ) + def test_ctor_w_client_options_universe(self): + creds = _make_credentials() + http = object() + client_options = {"universe_domain": "foo.com"} + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual(client._connection.API_BASE_URL, "https://bigquery.foo.com") + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_location(self): + from google.cloud.bigquery._http import Connection + + creds = _make_credentials() + http = object() + location = "us-central" + client = self._make_one( + project=self.PROJECT, credentials=creds, _http=http, location=location + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertEqual(client.location, location) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_query_job_config(self): + from google.cloud.bigquery._http import Connection + from google.cloud.bigquery import QueryJobConfig + + creds = _make_credentials() + http = object() + location = "us-central" + job_config = 
QueryJobConfig() + job_config.dry_run = True + + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + location=location, + default_query_job_config=job_config, + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertEqual(client.location, location) + + self.assertIsInstance(client._default_query_job_config, QueryJobConfig) + self.assertTrue(client._default_query_job_config.dry_run) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_load_job_config(self): + from google.cloud.bigquery._http import Connection + from google.cloud.bigquery import LoadJobConfig + + creds = _make_credentials() + http = object() + location = "us-central" + job_config = LoadJobConfig() + job_config.create_session = True + + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + location=location, + default_load_job_config=job_config, + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertEqual(client.location, location) + + self.assertIsInstance(client._default_load_job_config, LoadJobConfig) + self.assertTrue(client._default_load_job_config.create_session) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_job_miss_w_explict_project(self): + from google.cloud.exceptions import NotFound + + OTHER_PROJECT = "OTHER_PROJECT" + JOB_ID = "NONESUCH" + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + conn = client._client._connection = make_connection() + + with self.assertRaises(NotFound): + await client.get_job(JOB_ID, project=OTHER_PROJECT) + + conn.api_request.assert_called_once_with( + method="GET", + 
path="/projects/OTHER_PROJECT/jobs/NONESUCH", + query_params={"projection": "full"}, + timeout=DEFAULT_TIMEOUT, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_job_miss_w_client_location(self): + from google.cloud.exceptions import NotFound + + JOB_ID = "NONESUCH" + creds = _make_credentials() + client = self._make_one("client-proj", creds, location="client-loc") + conn = client._client._connection = make_connection() + + with self.assertRaises(NotFound): + await client.get_job(JOB_ID) + + conn.api_request.assert_called_once_with( + method="GET", + path="/projects/client-proj/jobs/NONESUCH", + query_params={"projection": "full", "location": "client-loc"}, + timeout=DEFAULT_TIMEOUT, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_job_hit_w_timeout(self): + from google.cloud.bigquery.job import CreateDisposition + from google.cloud.bigquery.job import QueryJob + from google.cloud.bigquery.job import WriteDisposition + + JOB_ID = "query_job" + QUERY_DESTINATION_TABLE = "query_destination_table" + QUERY = "SELECT * from test_dataset:test_table" + ASYNC_QUERY_DATA = { + "id": "{}:{}".format(self.PROJECT, JOB_ID), + "jobReference": { + "projectId": "resource-proj", + "jobId": "query_job", + "location": "us-east1", + }, + "state": "DONE", + "configuration": { + "query": { + "query": QUERY, + "destinationTable": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "tableId": QUERY_DESTINATION_TABLE, + }, + "createDisposition": CreateDisposition.CREATE_IF_NEEDED, + "writeDisposition": WriteDisposition.WRITE_TRUNCATE, + } + }, + } + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + conn = client._client._connection = make_connection(ASYNC_QUERY_DATA) + job_from_resource = QueryJob.from_api_repr(ASYNC_QUERY_DATA, client._client) + + job = await 
client.get_job(job_from_resource, timeout=7.5) + + self.assertIsInstance(job, QueryJob) + self.assertEqual(job.job_id, JOB_ID) + self.assertEqual(job.project, "resource-proj") + self.assertEqual(job.location, "us-east1") + self.assertEqual(job.create_disposition, CreateDisposition.CREATE_IF_NEEDED) + self.assertEqual(job.write_disposition, WriteDisposition.WRITE_TRUNCATE) + + conn.api_request.assert_called_once_with( + method="GET", + path="/projects/resource-proj/jobs/query_job", + query_params={"projection": "full", "location": "us-east1"}, + timeout=7.5, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test__get_query_results_miss_w_explicit_project_and_timeout(self): + from google.cloud.exceptions import NotFound + + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + conn = client._client._connection = make_connection() + path = "/projects/other-project/queries/nothere" + with self.assertRaises(NotFound): + with mock.patch( + "google.cloud.bigquery.opentelemetry_tracing._get_final_span_attributes" + ) as final_attributes: + await client._get_query_results( + "nothere", + None, + project="other-project", + location=self.LOCATION, + timeout_ms=500, + timeout=420, + ) + + final_attributes.assert_called_once_with({"path": path}, client._client, None) + + conn.api_request.assert_called_once_with( + method="GET", + path=path, + query_params={"maxResults": 0, "timeoutMs": 500, "location": self.LOCATION}, + timeout=420, + headers={"X-Server-Timeout": "420"}, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test__get_query_results_miss_w_short_timeout(self): + import google.cloud.bigquery.client + from google.cloud.exceptions import NotFound + + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + conn = client._client._connection = make_connection() + path = 
"/projects/other-project/queries/nothere" + with self.assertRaises(NotFound): + await client._get_query_results( + "nothere", + None, + project="other-project", + location=self.LOCATION, + timeout_ms=500, + timeout=1, + ) + + conn.api_request.assert_called_once_with( + method="GET", + path=path, + query_params={"maxResults": 0, "timeoutMs": 500, "location": self.LOCATION}, + timeout=google.cloud.bigquery.client._MIN_GET_QUERY_RESULTS_TIMEOUT, + headers={"X-Server-Timeout": "120"}, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test__get_query_results_miss_w_default_timeout(self): + import google.cloud.bigquery.client + from google.cloud.exceptions import NotFound + + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + conn = client._client._connection = make_connection() + path = "/projects/other-project/queries/nothere" + with self.assertRaises(NotFound): + await client._get_query_results( + "nothere", + None, + project="other-project", + location=self.LOCATION, + timeout_ms=500, + timeout=object(), # the api core default timeout + ) + + conn.api_request.assert_called_once_with( + method="GET", + path=path, + query_params={"maxResults": 0, "timeoutMs": 500, "location": self.LOCATION}, + timeout=google.cloud.bigquery.client._MIN_GET_QUERY_RESULTS_TIMEOUT, + headers={"X-Server-Timeout": "120"}, # why is this here? 
+ ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test__get_query_results_miss_w_client_location(self): + from google.cloud.exceptions import NotFound + + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds, location=self.LOCATION) + conn = client._client._connection = make_connection() + + with self.assertRaises(NotFound): + await client._get_query_results("nothere", None) + + conn.api_request.assert_called_once_with( + method="GET", + path="/projects/PROJECT/queries/nothere", + query_params={"maxResults": 0, "location": self.LOCATION}, + timeout=DEFAULT_TIMEOUT, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test__get_query_results_hit(self): + job_id = "query_job" + data = { + "kind": "bigquery#getQueryResultsResponse", + "etag": "some-tag", + "schema": { + "fields": [ + {"name": "title", "type": "STRING", "mode": "NULLABLE"}, + {"name": "unique_words", "type": "INTEGER", "mode": "NULLABLE"}, + ] + }, + "jobReference": {"projectId": self.PROJECT, "jobId": job_id}, + "totalRows": "10", + "totalBytesProcessed": "2464625", + "jobComplete": True, + "cacheHit": False, + } + + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + client._client._connection = make_connection(data) + query_results = await client._get_query_results(job_id, None) + + self.assertEqual(query_results.total_rows, 10) + self.assertTrue(query_results.complete) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_table(self): + path = "projects/%s/datasets/%s/tables/%s" % ( + self.PROJECT, + self.DS_ID, + self.TABLE_ID, + ) + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + resource = self._make_table_resource() + conn = 
client._client._connection = make_connection(resource) + with mock.patch( + "google.cloud.bigquery.opentelemetry_tracing._get_final_span_attributes" + ) as final_attributes: + table = await client.get_table(self.TABLE_REF, timeout=7.5) + + final_attributes.assert_called_once_with({"path": "/%s" % path}, client, None) + + conn.api_request.assert_called_once_with( + method="GET", path="/%s" % path, timeout=7.5 + ) + self.assertEqual(table.table_id, self.TABLE_ID) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_table_sets_user_agent(self): + creds = _make_credentials() + http = mock.create_autospec(requests.Session) + mock_response = http.request( + url=mock.ANY, method=mock.ANY, headers=mock.ANY, data=mock.ANY + ) + http.reset_mock() + http.is_mtls = False + mock_response.status_code = 200 + mock_response.json.return_value = self._make_table_resource() + user_agent_override = client_info.ClientInfo(user_agent="my-application/1.2.3") + client = self._make_one( + project=self.PROJECT, + credentials=creds, + client_info=user_agent_override, + _http=http, + ) + + await client.get_table(self.TABLE_REF) + + expected_user_agent = user_agent_override.to_user_agent() + http.request.assert_called_once_with( + url=mock.ANY, + method="GET", + headers={ + "X-Goog-API-Client": expected_user_agent, + "Accept-Encoding": "gzip", + "User-Agent": expected_user_agent, + }, + data=mock.ANY, + timeout=DEFAULT_TIMEOUT, + ) + self.assertIn("my-application/1.2.3", expected_user_agent) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_list_partitions(self): + from google.cloud.bigquery.table import Table + + rows = 3 + meta_info = _make_list_partitons_meta_info( + self.PROJECT, self.DS_ID, self.TABLE_ID, rows + ) + + data = { + "totalRows": str(rows), + "rows": [ + {"f": [{"v": "20180101"}]}, + {"f": [{"v": "20180102"}]}, + {"f": 
[{"v": "20180103"}]}, + ], + } + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + client._client._connection = make_connection(meta_info, data) + table = Table(self.TABLE_REF) + + partition_list = await client.list_partitions(table) + self.assertEqual(len(partition_list), rows) + self.assertIn("20180102", partition_list) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_list_partitions_with_string_id(self): + meta_info = _make_list_partitons_meta_info( + self.PROJECT, self.DS_ID, self.TABLE_ID, 0 + ) + + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + client._client._connection = make_connection(meta_info, {}) + + partition_list = await client.list_partitions( + "{}.{}".format(self.DS_ID, self.TABLE_ID) + ) + + self.assertEqual(len(partition_list), 0) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test__call_api_applying_custom_retry_on_timeout(self): + from concurrent.futures import TimeoutError + from google.cloud.bigquery.retry import DEFAULT_ASYNC_RETRY + + creds = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=creds) + + api_request_patcher = mock.patch.object( + client._client._connection, + "api_request", + side_effect=[TimeoutError, "result"], + ) + retry = DEFAULT_ASYNC_RETRY.with_deadline(1).with_predicate( + lambda exc: isinstance(exc, TimeoutError) + ) + + with api_request_patcher as fake_api_request: + result = await client._call_api(retry, foo="bar") + + self.assertEqual(result, "result") + self.assertEqual( + fake_api_request.call_args_list, + [mock.call(foo="bar"), mock.call(foo="bar")], # was retried once + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + 
+    @asyncio_run
+    async def test__call_api_span_creator_not_called(self):
+        from concurrent.futures import TimeoutError
+        from google.cloud.bigquery.retry import DEFAULT_ASYNC_RETRY
+
+        creds = _make_credentials()
+        client = self._make_one(project=self.PROJECT, credentials=creds)
+
+        api_request_patcher = mock.patch.object(
+            client._client._connection,
+            "api_request",
+            side_effect=[TimeoutError, "result"],
+        )
+        retry = DEFAULT_ASYNC_RETRY.with_deadline(1).with_predicate(
+            lambda exc: isinstance(exc, TimeoutError)
+        )
+
+        with api_request_patcher:
+            with mock.patch(
+                "google.cloud.bigquery.opentelemetry_tracing._get_final_span_attributes"
+            ) as final_attributes:
+                # No span_name supplied -> no span should be created.
+                await client._call_api(retry)
+
+        final_attributes.assert_not_called()
+
+    @pytest.mark.skipif(
+        sys.version_info < (3, 9), reason="requires python3.9 or higher"
+    )
+    @asyncio_run
+    async def test__call_api_span_creator_called(self):
+        from concurrent.futures import TimeoutError
+        from google.cloud.bigquery.retry import DEFAULT_ASYNC_RETRY
+
+        creds = _make_credentials()
+        client = self._make_one(project=self.PROJECT, credentials=creds)
+
+        api_request_patcher = mock.patch.object(
+            client._client._connection,
+            "api_request",
+            side_effect=[TimeoutError, "result"],
+        )
+        retry = DEFAULT_ASYNC_RETRY.with_deadline(1).with_predicate(
+            lambda exc: isinstance(exc, TimeoutError)
+        )
+
+        with api_request_patcher:
+            with mock.patch(
+                "google.cloud.bigquery.opentelemetry_tracing._get_final_span_attributes"
+            ) as final_attributes:
+                # span_name supplied -> span attributes must be computed.
+                await client._call_api(
+                    retry,
+                    span_name="test_name",
+                    span_attributes={"test_attribute": "test_attribute-value"},
+                )
+
+        final_attributes.assert_called_once()
+
+# TODO: add tests showing that in-flight calls are cancellable, plus coverage
+# for RowIterator / paginated access once the async client supports them.