From f343e0fa2217c074e9d74a7e14862fa8dd9e66b1 Mon Sep 17 00:00:00 2001
From: Linchin
Date: Thu, 3 Oct 2024 14:28:36 -0700
Subject: [PATCH 1/3] mapper proto2

---
 google/cloud/bigquery/_helpers.py        |  4 ++--
 google/cloud/bigquery/_pandas_helpers.py |  4 ++--
 google/cloud/bigquery/client.py          |  8 ++++++++
 google/cloud/bigquery/schema.py          | 14 +++++++++++++-
 google/cloud/bigquery/table.py           | 11 ++++++-----
 5 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/google/cloud/bigquery/_helpers.py b/google/cloud/bigquery/_helpers.py
index 1eda80712..499448a98 100644
--- a/google/cloud/bigquery/_helpers.py
+++ b/google/cloud/bigquery/_helpers.py
@@ -389,7 +389,7 @@ def default_converter(value, field):
     return converter(resource, field)
 
 
-def _row_tuple_from_json(row, schema):
+def _row_tuple_from_json(row, schema, types_mapper):
     """Convert JSON row data to row with appropriate types.
 
     Note: ``row['f']`` and ``schema`` are presumed to be of the same length.
@@ -406,7 +406,7 @@ def _row_tuple_from_json(row, schema):
     """
     from google.cloud.bigquery.schema import _to_schema_fields
 
-    schema = _to_schema_fields(schema)
+    schema = _to_schema_fields(schema, types_mapper)
 
     row_data = []
     for field, cell in zip(schema, row["f"]):
diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py
index 210ab4875..323027bc1 100644
--- a/google/cloud/bigquery/_pandas_helpers.py
+++ b/google/cloud/bigquery/_pandas_helpers.py
@@ -710,7 +710,7 @@ def _row_iterator_page_to_arrow(page, column_names, arrow_types):
     return pyarrow.RecordBatch.from_arrays(arrays, names=column_names)
 
 
-def download_arrow_row_iterator(pages, bq_schema):
+def download_arrow_row_iterator(pages, bq_schema, types_mapper=None):
     """Use HTTP JSON RowIterator to construct an iterable of RecordBatches.
 
     Args:
@@ -725,7 +725,7 @@ def download_arrow_row_iterator(pages, bq_schema):
         :class:`pyarrow.RecordBatch`
             The next page of records as a ``pyarrow`` record batch.
     """
-    bq_schema = schema._to_schema_fields(bq_schema)
+    bq_schema = schema._to_schema_fields(bq_schema, types_mapper)
     column_names = bq_to_arrow_schema(bq_schema) or [field.name for field in bq_schema]
     arrow_types = [bq_to_arrow_data_type(field) for field in bq_schema]
 
diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py
index 1c222f2dd..dd8329ea7 100644
--- a/google/cloud/bigquery/client.py
+++ b/google/cloud/bigquery/client.py
@@ -219,6 +219,9 @@ class Client(ClientWithProject):
         client_options (Optional[Union[google.api_core.client_options.ClientOptions, Dict]]):
             Client options used to set user options on the client. API Endpoint
             should be set through client_options.
+        types_mapper (typing.Callable):
+            Optional callable that maps a schema field name to the BigQuery
+            type name to use for that field when converting query results.
 
     Raises:
         google.auth.exceptions.DefaultCredentialsError:
@@ -239,6 +242,8 @@ def __init__(
         default_load_job_config=None,
         client_info=None,
         client_options=None,
+        *,
+        types_mapper=None,
     ) -> None:
         super(Client, self).__init__(
             project=project,
@@ -275,6 +280,9 @@ def __init__(
         # Use property setter so validation can run.
         self.default_query_job_config = default_query_job_config
 
+        # Client level types mapper setting.
+        self._types_mapper = types_mapper
+
     @property
     def location(self):
         """Default location for jobs / datasets / tables."""
diff --git a/google/cloud/bigquery/schema.py b/google/cloud/bigquery/schema.py
index f5b03cbef..e45c77b03 100644
--- a/google/cloud/bigquery/schema.py
+++ b/google/cloud/bigquery/schema.py
@@ -15,6 +15,7 @@
 """Schemas for BigQuery tables / queries."""
 
 import collections
+import copy
 import enum
 from typing import Any, Dict, Iterable, Optional, Union, cast
 
@@ -473,7 +474,7 @@ def _build_schema_resource(fields):
     return [field.to_api_repr() for field in fields]
 
 
-def _to_schema_fields(schema):
+def _to_schema_fields(schema, types_mapper):
     """Coerce `schema` to a list of schema field instances.
 
     Args:
@@ -493,12 +494,23 @@ def _to_schema_fields(schema):
         sequence is not a :class:`~google.cloud.bigquery.schema.SchemaField`
         instance or a compatible mapping representation of the field.
     """
+    schema_fields = []
     for field in schema:
         if not isinstance(field, (SchemaField, collections.abc.Mapping)):
             raise ValueError(
                 "Schema items must either be fields or compatible "
                 "mapping representations."
             )
+        if types_mapper and types_mapper(field.name):
+            custom_field = copy.deepcopy(field)
+            custom_field._properties["type"] = types_mapper(field.name)
+            schema_fields.append(custom_field)
+        elif isinstance(field, SchemaField):
+            schema_fields.append(field)
+        else:
+            schema_fields.append(SchemaField.from_api_repr(field))
+
+    return schema_fields
 
     return [
         field if isinstance(field, SchemaField) else SchemaField.from_api_repr(field)
diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index faf827be4..3ee24e613 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1594,11 +1594,12 @@ def __init__(
         project: Optional[str] = None,
         num_dml_affected_rows: Optional[int] = None,
     ):
+        _item_to_row_with_mapper = functools.partial(_item_to_row, types_mapper=client._types_mapper)
         super(RowIterator, self).__init__(
             client,
             api_request,
             path,
-            item_to_value=_item_to_row,
+            item_to_value=_item_to_row_with_mapper,
             items_key="rows",
             page_token=page_token,
             max_results=max_results,
@@ -1606,7 +1607,7 @@ def __init__(
             page_start=_rows_page_start,
             next_token="pageToken",
         )
-        schema = _to_schema_fields(schema)
+        schema = _to_schema_fields(schema, client._types_mapper)
         self._field_to_index = _helpers._field_to_index_mapping(schema)
         self._page_size = page_size
         self._preserve_order = False
@@ -1854,7 +1855,7 @@ def to_arrow_iterable(
             max_queue_size=max_queue_size,
         )
         tabledata_list_download = functools.partial(
-            _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema
+            _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema, types_mapper=self.client._types_mapper,
         )
         return self._to_page_iterable(
             bqstorage_download,
@@ -3218,22 +3219,22 @@ def from_api_repr(cls, resource: Dict[str, Any]) -> "TableConstraints":
     return cls(primary_key, foreign_keys)
 
 
-def _item_to_row(iterator, resource):
+def _item_to_row(iterator, resource, types_mapper):
     """Convert a JSON row to the native object.
 
     .. note::
 
         This method assumes that the project found in the resource matches
        the client's project.
 
     Args:
         iterator (google.api_core.page_iterator.Iterator): The iterator that is currently in use.
         resource (Dict): An item to be converted to a row.
 
     Returns:
         google.cloud.bigquery.table.Row: The next row in the page.
""" return Row( - _helpers._row_tuple_from_json(resource, iterator.schema), + _helpers._row_tuple_from_json(resource, iterator.schema, types_mapper), iterator._field_to_index, ) From d906c39e82935363f2590b60a42153a5cabf53fc Mon Sep 17 00:00:00 2001 From: Linchin Date: Mon, 2 Dec 2024 15:44:21 -0800 Subject: [PATCH 2/3] update logic and fix unit tests --- google/cloud/bigquery/_helpers.py | 2 +- google/cloud/bigquery/client.py | 12 ++++++++++++ google/cloud/bigquery/schema.py | 29 +++++++++++++---------------- google/cloud/bigquery/table.py | 22 ++++++++++++++++++---- tests/unit/test_table.py | 7 +++++-- 5 files changed, 49 insertions(+), 23 deletions(-) diff --git a/google/cloud/bigquery/_helpers.py b/google/cloud/bigquery/_helpers.py index 499448a98..11dc0bebe 100644 --- a/google/cloud/bigquery/_helpers.py +++ b/google/cloud/bigquery/_helpers.py @@ -389,7 +389,7 @@ def default_converter(value, field): return converter(resource, field) -def _row_tuple_from_json(row, schema, types_mapper): +def _row_tuple_from_json(row, schema, types_mapper=None): """Convert JSON row data to row with appropriate types. Note: ``row['f']`` and ``schema`` are presumed to be of the same length. diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index 8e2e3bd86..e14edc5d3 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -31,6 +31,7 @@ import typing from typing import ( Any, + Callable, Dict, IO, Iterable, @@ -316,6 +317,17 @@ def default_load_job_config(self): def default_load_job_config(self, value: LoadJobConfig): self._default_load_job_config = copy.deepcopy(value) + + @property + def types_mapper(self): + """TODO: add docstring + """ + return self._types_mapper + + @types_mapper.setter + def types_mapper(self, value: Optional[Callable]): + self._types_mapper = value + def close(self): """Close the underlying transport objects, releasing system resources. diff --git a/google/cloud/bigquery/schema.py b/google/cloud/bigquery/schema.py index e45c77b03..66f1cb9af 100644 --- a/google/cloud/bigquery/schema.py +++ b/google/cloud/bigquery/schema.py @@ -474,7 +474,7 @@ def _build_schema_resource(fields): return [field.to_api_repr() for field in fields] -def _to_schema_fields(schema, types_mapper): +def _to_schema_fields(schema, types_mapper=None): """Coerce `schema` to a list of schema field instances. Args: @@ -496,26 +496,23 @@ def _to_schema_fields(schema, types_mapper): """ schema_fields = [] for field in schema: - if not isinstance(field, (SchemaField, collections.abc.Mapping)): + if isinstance(field, SchemaField): + current_field = copy.deepcopy(field) + field_name = field.name + elif isinstance(field, collections.abc.Mapping): + current_field = SchemaField.from_api_repr(field) + field_name = field["name"] + else: raise ValueError( "Schema items must either be fields or compatible " "mapping representations." 
-            )
-        if types_mapper and types_mapper(field.name):
-            custom_field = copy.deepcopy(field)
-            custom_field._properties["type"] = types_mapper(field.name)
-            schema_fields.append(custom_field)
-        elif isinstance(field, SchemaField):
-            schema_fields.append(field)
-        else:
-            schema_fields.append(SchemaField.from_api_repr(field))
+            )
 
-    return schema_fields
+        if types_mapper and types_mapper(field_name):
+            current_field._properties["type"] = types_mapper(field_name)
 
-    return [
-        field if isinstance(field, SchemaField) else SchemaField.from_api_repr(field)
-        for field in schema
-    ]
+        schema_fields.append(current_field)
+    return schema_fields
 
 
 class PolicyTagList(object):
diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index 3ee24e613..7493f48a9 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1594,7 +1594,17 @@ def __init__(
         project: Optional[str] = None,
         num_dml_affected_rows: Optional[int] = None,
     ):
-        _item_to_row_with_mapper = functools.partial(_item_to_row, types_mapper=client._types_mapper)
+        if client:
+            types_mapper = client.types_mapper
+        else:
+            types_mapper = None
+
+        if types_mapper:
+            _item_to_row_with_mapper = functools.partial(_item_to_row, types_mapper=types_mapper)
+        else:
+            _item_to_row_with_mapper = _item_to_row
+
+        # breakpoint()
         super(RowIterator, self).__init__(
             client,
             api_request,
@@ -1607,7 +1617,7 @@ def __init__(
             page_start=_rows_page_start,
             next_token="pageToken",
         )
-        schema = _to_schema_fields(schema, client._types_mapper)
+        schema = _to_schema_fields(schema, types_mapper)
         self._field_to_index = _helpers._field_to_index_mapping(schema)
         self._page_size = page_size
         self._preserve_order = False
@@ -1854,8 +1864,12 @@ def to_arrow_iterable(
             selected_fields=self._selected_fields,
             max_queue_size=max_queue_size,
         )
+        if self.client is not None:
+            types_mapper = self.client.types_mapper
+        else:
+            types_mapper = None
         tabledata_list_download = functools.partial(
-            _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema, types_mapper=self.client._types_mapper,
+            _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema, types_mapper=types_mapper,
         )
         return self._to_page_iterable(
             bqstorage_download,
@@ -3219,7 +3233,7 @@ def from_api_repr(cls, resource: Dict[str, Any]) -> "TableConstraints":
     return cls(primary_key, foreign_keys)
 
 
-def _item_to_row(iterator, resource, types_mapper):
+def _item_to_row(iterator, resource, types_mapper=None):
     """Convert a JSON row to the native object.
 
     .. note::
diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py
index 018a096df..1500abde1 100644
--- a/tests/unit/test_table.py
+++ b/tests/unit/test_table.py
@@ -36,8 +36,10 @@
 def _mock_client():
     from google.cloud.bigquery import client
 
-    mock_client = mock.create_autospec(client.Client)
-    mock_client.project = "my-project"
+    mock_client = client.Client(project="my-project")
+    mock_client._ensure_bqstorage_client = mock.MagicMock(
+        mock_client._ensure_bqstorage_client,
+    )
     return mock_client
 
 
@@ -2085,6 +2087,7 @@ def test_constructor_with_dict_schema(self):
         ]
 
         iterator = self._make_one(schema=schema)
+        #breakpoint()
 
         expected_schema = [
             SchemaField("full_name", "STRING", mode="REQUIRED"),

From ef7686d83c4a35cf984a008021d9f24a7e6cc839 Mon Sep 17 00:00:00 2001
From: Linchin
Date: Mon, 2 Dec 2024 15:48:54 -0800
Subject: [PATCH 3/3] remove breakpoints

---
 google/cloud/bigquery/table.py | 1 -
 tests/unit/test_table.py       | 1 -
 2 files changed, 2 deletions(-)

diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index 7493f48a9..6435c6114 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1604,7 +1604,6 @@ def __init__(
         else:
             _item_to_row_with_mapper = _item_to_row
 
-        # breakpoint()
         super(RowIterator, self).__init__(
             client,
             api_request,
diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py
index 1500abde1..dc02d5787 100644
--- a/tests/unit/test_table.py
+++ b/tests/unit/test_table.py
@@ -2087,7 +2087,6 @@ def test_constructor_with_dict_schema(self):
         ]
 
         iterator = self._make_one(schema=schema)
-        #breakpoint()
 
         expected_schema = [
             SchemaField("full_name", "STRING", mode="REQUIRED"),
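
Note for reviewers: a minimal usage sketch of the types_mapper hook this
series introduces, not part of the patches themselves. The project ID, the
table path, and the "json_col" field name below are illustrative
assumptions, as is overriding a field to the JSON type.

    from google.cloud import bigquery

    # Hypothetical mapper: treat the column named "json_col" as JSON
    # regardless of its declared schema type. Returning None (or any
    # falsy value) leaves a field's type untouched.
    def force_json(field_name):
        return "JSON" if field_name == "json_col" else None

    # The mapper can be passed at construction (PATCH 1/3) ...
    client = bigquery.Client(project="my-project", types_mapper=force_json)

    # ... or assigned later through the property added in PATCH 2/3.
    client.types_mapper = force_json

    # Query results are decoded through RowIterator, which reads the
    # mapper off the client and applies it while coercing the schema.
    rows = client.query("SELECT * FROM `my_dataset.my_table`").result()
    for row in rows:
        print(row["json_col"])  # value converted using the overridden type

Because _to_schema_fields() copies each field before rewriting its "type"
property, the override only affects how rows are decoded and never mutates
the caller's schema objects.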