diff --git a/google/cloud/bigquery/_helpers.py b/google/cloud/bigquery/_helpers.py
index 1eda80712..11dc0bebe 100644
--- a/google/cloud/bigquery/_helpers.py
+++ b/google/cloud/bigquery/_helpers.py
@@ -389,7 +389,7 @@ def default_converter(value, field):
     return converter(resource, field)
 
 
-def _row_tuple_from_json(row, schema):
+def _row_tuple_from_json(row, schema, types_mapper=None):
     """Convert JSON row data to row with appropriate types.
 
     Note: ``row['f']`` and ``schema`` are presumed to be of the same length.
@@ -406,7 +406,7 @@ def _row_tuple_from_json(row, schema):
     """
     from google.cloud.bigquery.schema import _to_schema_fields
 
-    schema = _to_schema_fields(schema)
+    schema = _to_schema_fields(schema, types_mapper)
 
     row_data = []
     for field, cell in zip(schema, row["f"]):
diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py
index bf7d10c0f..37456f550 100644
--- a/google/cloud/bigquery/_pandas_helpers.py
+++ b/google/cloud/bigquery/_pandas_helpers.py
@@ -711,7 +711,7 @@ def _row_iterator_page_to_arrow(page, column_names, arrow_types):
     return pyarrow.RecordBatch.from_arrays(arrays, names=column_names)
 
 
-def download_arrow_row_iterator(pages, bq_schema):
+def download_arrow_row_iterator(pages, bq_schema, types_mapper=None):
     """Use HTTP JSON RowIterator to construct an iterable of RecordBatches.
 
     Args:
@@ -726,7 +726,7 @@ def download_arrow_row_iterator(pages, bq_schema):
         :class:`pyarrow.RecordBatch`
             The next page of records as a ``pyarrow`` record batch.
     """
-    bq_schema = schema._to_schema_fields(bq_schema)
+    bq_schema = schema._to_schema_fields(bq_schema, types_mapper)
     column_names = bq_to_arrow_schema(bq_schema) or [field.name for field in bq_schema]
     arrow_types = [bq_to_arrow_data_type(field) for field in bq_schema]
 
diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py
index 97f239f7a..e14edc5d3 100644
--- a/google/cloud/bigquery/client.py
+++ b/google/cloud/bigquery/client.py
@@ -31,6 +31,7 @@ import typing
 from typing import (
     Any,
+    Callable,
     Dict,
     IO,
     Iterable,
@@ -219,6 +220,9 @@ class Client(ClientWithProject):
         client_options (Optional[Union[google.api_core.client_options.ClientOptions, Dict]]):
             Client options used to set user options on the client. API Endpoint
             should be set through client_options.
+        types_mapper (Optional[Callable[[str], Optional[str]]]):
+            A callable that maps a field name to a BigQuery type string used
+            to override the schema-reported type when decoding rows.
 
     Raises:
         google.auth.exceptions.DefaultCredentialsError:
@@ -239,6 +243,8 @@ def __init__(
         default_load_job_config=None,
         client_info=None,
         client_options=None,
+        *,
+        types_mapper=None,
     ) -> None:
         super(Client, self).__init__(
             project=project,
@@ -275,6 +281,9 @@ def __init__(
         # Use property setter so validation can run.
         self.default_query_job_config = default_query_job_config
 
+        # Client level types mapper setting.
+        self._types_mapper = types_mapper
+
     @property
     def location(self):
         """Default location for jobs / datasets / tables."""
@@ -308,6 +317,17 @@ def default_load_job_config(self):
     def default_load_job_config(self, value: LoadJobConfig):
         self._default_load_job_config = copy.deepcopy(value)
 
+    @property
+    def types_mapper(self):
+        """Optional[Callable]: A callable mapping a field name to an
+        overriding BigQuery type string, or :data:`None`.
+        """
+        return self._types_mapper
+
+    @types_mapper.setter
+    def types_mapper(self, value: Optional[Callable]):
+        self._types_mapper = value
+
     def close(self):
         """Close the underlying transport objects, releasing system resources.
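For context on how the new client-level hook is intended to be used, here is a minimal sketch assuming this patch is applied; the mapper, dataset, table, and column names are hypothetical, not part of the diff:

```python
from google.cloud import bigquery


def my_types_mapper(field_name):
    # Hypothetical override: decode the "geo_wkt" column as GEOGRAPHY
    # regardless of the type the API reports. Returning None (or any
    # falsy value) leaves a field's type untouched.
    return "GEOGRAPHY" if field_name == "geo_wkt" else None


client = bigquery.Client(project="my-project", types_mapper=my_types_mapper)

# Every RowIterator created through this client consults the mapper when
# coercing schema fields, so "geo_wkt" cells decode with the GEOGRAPHY
# converter instead of the schema-reported type.
rows = client.list_rows("my_dataset.my_table")
```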
diff --git a/google/cloud/bigquery/schema.py b/google/cloud/bigquery/schema.py
index f5b03cbef..66f1cb9af 100644
--- a/google/cloud/bigquery/schema.py
+++ b/google/cloud/bigquery/schema.py
@@ -15,6 +15,7 @@
 """Schemas for BigQuery tables / queries."""
 
 import collections
+import copy
 import enum
 from typing import Any, Dict, Iterable, Optional, Union, cast
 
@@ -473,7 +474,7 @@ def _build_schema_resource(fields):
     return [field.to_api_repr() for field in fields]
 
 
-def _to_schema_fields(schema):
+def _to_schema_fields(schema, types_mapper=None):
     """Coerce `schema` to a list of schema field instances.
 
     Args:
@@ -493,17 +494,28 @@ def _to_schema_fields(schema):
             sequence is not a :class:`~google.cloud.bigquery.schema.SchemaField`
             instance or a compatible mapping representation of the field.
     """
+    schema_fields = []
     for field in schema:
-        if not isinstance(field, (SchemaField, collections.abc.Mapping)):
+        if isinstance(field, SchemaField):
+            current_field = copy.deepcopy(field)
+            field_name = field.name
+        elif isinstance(field, collections.abc.Mapping):
+            current_field = SchemaField.from_api_repr(field)
+            field_name = field["name"]
+        else:
             raise ValueError(
                 "Schema items must either be fields or compatible "
                 "mapping representations."
             )
+
+        # Call the mapper once per field; a falsy result leaves the type as-is.
+        override_type = types_mapper(field_name) if types_mapper else None
+        if override_type:
+            current_field._properties["type"] = override_type
+
+        schema_fields.append(current_field)
 
-    return [
-        field if isinstance(field, SchemaField) else SchemaField.from_api_repr(field)
-        for field in schema
-    ]
+    return schema_fields
 
 
 class PolicyTagList(object):
diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index dcaf377e3..304d7f6b4 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1594,11 +1594,20 @@ def __init__(
         project: Optional[str] = None,
         num_dml_affected_rows: Optional[int] = None,
     ):
+        types_mapper = client.types_mapper if client is not None else None
+
+        if types_mapper:
+            _item_to_row_with_mapper = functools.partial(
+                _item_to_row, types_mapper=types_mapper
+            )
+        else:
+            _item_to_row_with_mapper = _item_to_row
+
         super(RowIterator, self).__init__(
             client,
             api_request,
             path,
-            item_to_value=_item_to_row,
+            item_to_value=_item_to_row_with_mapper,
             items_key="rows",
             page_token=page_token,
             max_results=max_results,
@@ -1606,7 +1615,7 @@ def __init__(
             page_start=_rows_page_start,
             next_token="pageToken",
         )
-        schema = _to_schema_fields(schema)
+        schema = _to_schema_fields(schema, types_mapper)
         self._field_to_index = _helpers._field_to_index_mapping(schema)
         self._page_size = page_size
         self._preserve_order = False
@@ -1871,8 +1880,12 @@ def to_arrow_iterable(
             max_queue_size=max_queue_size,
             max_stream_count=max_stream_count,
         )
+        types_mapper = self.client.types_mapper if self.client is not None else None
         tabledata_list_download = functools.partial(
-            _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema
+            _pandas_helpers.download_arrow_row_iterator,
+            iter(self.pages),
+            self.schema,
+            types_mapper=types_mapper,
         )
         return self._to_page_iterable(
             bqstorage_download,
@@ -3262,7 +3275,7 @@ def from_api_repr(cls, resource: Dict[str, Any]) -> "TableConstraints":
         return cls(primary_key, foreign_keys)
 
 
-def _item_to_row(iterator, resource):
+def _item_to_row(iterator, resource, types_mapper=None):
     """Convert a JSON row to the native object.
 
     .. note::
@@ -3279,7 +3292,7 @@ def _item_to_row(iterator, resource):
     Returns:
         google.cloud.bigquery.table.Row:
             The next row in the page.
""" return Row( - _helpers._row_tuple_from_json(resource, iterator.schema), + _helpers._row_tuple_from_json(resource, iterator.schema, types_mapper), iterator._field_to_index, ) diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py index d81ad2dca..fa083baea 100644 --- a/tests/unit/test_table.py +++ b/tests/unit/test_table.py @@ -36,8 +36,10 @@ def _mock_client(): from google.cloud.bigquery import client - mock_client = mock.create_autospec(client.Client) - mock_client.project = "my-project" + mock_client = client.Client(project="my-project") + mock_client._ensure_bqstorage_client = mock.MagicMock( + mock_client._ensure_bqstorage_client, + ) return mock_client