Skip to content

Commit 5cd1b07

Browse files
committed
Merge branch 'main' into jco-ga
2 parents 698fdbb + cb646ce commit 5cd1b07

File tree

6 files changed

+187
-187
lines changed

6 files changed

+187
-187
lines changed

google/cloud/bigquery/_pandas_helpers.py

Lines changed: 66 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -508,31 +508,37 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
508508
bq_schema_unused = set()
509509

510510
bq_schema_out = []
511-
unknown_type_fields = []
512-
511+
unknown_type_columns = []
512+
dataframe_reset_index = dataframe.reset_index()
513513
for column, dtype in list_columns_and_indexes(dataframe):
514-
# Use provided type from schema, if present.
514+
# Step 1: use provided type from schema, if present.
515515
bq_field = bq_schema_index.get(column)
516516
if bq_field:
517517
bq_schema_out.append(bq_field)
518518
bq_schema_unused.discard(bq_field.name)
519519
continue
520520

521-
# Otherwise, try to automatically determine the type based on the
521+
# Step 2: try to automatically determine the type based on the
522522
# pandas dtype.
523523
bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
524524
if bq_type is None:
525-
sample_data = _first_valid(dataframe.reset_index()[column])
525+
sample_data = _first_valid(dataframe_reset_index[column])
526526
if (
527527
isinstance(sample_data, _BaseGeometry)
528528
and sample_data is not None # Paranoia
529529
):
530530
bq_type = "GEOGRAPHY"
531-
bq_field = schema.SchemaField(column, bq_type)
532-
bq_schema_out.append(bq_field)
531+
if bq_type is not None:
532+
bq_schema_out.append(schema.SchemaField(column, bq_type))
533+
continue
534+
535+
# Step 3: try with pyarrow if available
536+
bq_field = _get_schema_by_pyarrow(column, dataframe_reset_index[column])
537+
if bq_field is not None:
538+
bq_schema_out.append(bq_field)
539+
continue
533540

534-
if bq_field.field_type is None:
535-
unknown_type_fields.append(bq_field)
541+
unknown_type_columns.append(column)
536542

537543
# Catch any schema mismatch. The developer explicitly asked to serialize a
538544
# column, but it was not found.
@@ -543,98 +549,70 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
543549
)
544550
)
545551

546-
# If schema detection was not successful for all columns, also try with
547-
# pyarrow, if available.
548-
if unknown_type_fields:
549-
if not pyarrow:
550-
msg = "Could not determine the type of columns: {}".format(
551-
", ".join(field.name for field in unknown_type_fields)
552-
)
553-
warnings.warn(msg)
554-
return None # We cannot detect the schema in full.
555-
556-
# The augment_schema() helper itself will also issue unknown type
557-
# warnings if detection still fails for any of the fields.
558-
bq_schema_out = augment_schema(dataframe, bq_schema_out)
552+
if unknown_type_columns != []:
553+
msg = "Could not determine the type of columns: {}".format(
554+
", ".join(unknown_type_columns)
555+
)
556+
warnings.warn(msg)
557+
return None # We cannot detect the schema in full.
559558

560-
return tuple(bq_schema_out) if bq_schema_out else None
559+
return tuple(bq_schema_out)
561560

562561

563-
def augment_schema(dataframe, current_bq_schema):
564-
"""Try to deduce the unknown field types and return an improved schema.
562+
def _get_schema_by_pyarrow(name, series):
563+
"""Attempt to detect the type of the given series by leveraging PyArrow's
564+
type detection capabilities.
565565
566-
This function requires ``pyarrow`` to run. If all the missing types still
567-
cannot be detected, ``None`` is returned. If all types are already known,
568-
a shallow copy of the given schema is returned.
566+
This function requires the ``pyarrow`` library to be installed and
567+
available. If the series type cannot be determined or ``pyarrow`` is not
568+
available, ``None`` is returned.
569569
570570
Args:
571-
dataframe (pandas.DataFrame):
572-
DataFrame for which some of the field types are still unknown.
573-
current_bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
574-
A BigQuery schema for ``dataframe``. The types of some or all of
575-
the fields may be ``None``.
571+
name (str):
572+
the column name of the SchemaField.
573+
series (pandas.Series):
574+
The Series data for which to detect the data type.
576575
Returns:
577-
Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]
576+
Optional[google.cloud.bigquery.schema.SchemaField]:
577+
A ``SchemaField`` whose type is a BigQuery-compatible type string (e.g.,
578+
"STRING", "INTEGER", "TIMESTAMP", "DATETIME", "NUMERIC", "BIGNUMERIC")
579+
and whose mode is "NULLABLE" or "REPEATED".
580+
Returns ``None`` if the type cannot be determined or ``pyarrow``
581+
is not imported.
578582
"""
579-
# pytype: disable=attribute-error
580-
augmented_schema = []
581-
unknown_type_fields = []
582-
for field in current_bq_schema:
583-
if field.field_type is not None:
584-
augmented_schema.append(field)
585-
continue
586-
587-
arrow_table = pyarrow.array(dataframe.reset_index()[field.name])
588-
589-
if pyarrow.types.is_list(arrow_table.type):
590-
# `pyarrow.ListType`
591-
detected_mode = "REPEATED"
592-
detected_type = _pyarrow_helpers.arrow_scalar_ids_to_bq(
593-
arrow_table.values.type.id
594-
)
595-
596-
# For timezone-naive datetimes, pyarrow assumes the UTC timezone and adds
597-
# it to such datetimes, causing them to be recognized as TIMESTAMP type.
598-
# We thus additionally check the actual data to see if we need to overrule
599-
# that and choose DATETIME instead.
600-
# Note that this should only be needed for datetime values inside a list,
601-
# since scalar datetime values have a proper Pandas dtype that allows
602-
# distinguishing between timezone-naive and timezone-aware values before
603-
# even requiring the additional schema augment logic in this method.
604-
if detected_type == "TIMESTAMP":
605-
valid_item = _first_array_valid(dataframe[field.name])
606-
if isinstance(valid_item, datetime) and valid_item.tzinfo is None:
607-
detected_type = "DATETIME"
608-
else:
609-
detected_mode = field.mode
610-
detected_type = _pyarrow_helpers.arrow_scalar_ids_to_bq(arrow_table.type.id)
611-
if detected_type == "NUMERIC" and arrow_table.type.scale > 9:
612-
detected_type = "BIGNUMERIC"
613583

614-
if detected_type is None:
615-
unknown_type_fields.append(field)
616-
continue
584+
if not pyarrow:
585+
return None
617586

618-
new_field = schema.SchemaField(
619-
name=field.name,
620-
field_type=detected_type,
621-
mode=detected_mode,
622-
description=field.description,
623-
fields=field.fields,
624-
)
625-
augmented_schema.append(new_field)
587+
arrow_table = pyarrow.array(series)
588+
if pyarrow.types.is_list(arrow_table.type):
589+
# `pyarrow.ListType`
590+
mode = "REPEATED"
591+
type = _pyarrow_helpers.arrow_scalar_ids_to_bq(arrow_table.values.type.id)
592+
593+
# For timezone-naive datetimes, pyarrow assumes the UTC timezone and adds
594+
# it to such datetimes, causing them to be recognized as TIMESTAMP type.
595+
# We thus additionally check the actual data to see if we need to overrule
596+
# that and choose DATETIME instead.
597+
# Note that this should only be needed for datetime values inside a list,
598+
# since scalar datetime values have a proper Pandas dtype that allows
599+
# distinguishing between timezone-naive and timezone-aware values before
600+
# even requiring the additional schema augment logic in this method.
601+
if type == "TIMESTAMP":
602+
valid_item = _first_array_valid(series)
603+
if isinstance(valid_item, datetime) and valid_item.tzinfo is None:
604+
type = "DATETIME"
605+
else:
606+
mode = "NULLABLE" # default mode
607+
type = _pyarrow_helpers.arrow_scalar_ids_to_bq(arrow_table.type.id)
608+
if type == "NUMERIC" and arrow_table.type.scale > 9:
609+
type = "BIGNUMERIC"
626610

627-
if unknown_type_fields:
628-
warnings.warn(
629-
"Pyarrow could not determine the type of columns: {}.".format(
630-
", ".join(field.name for field in unknown_type_fields)
631-
)
632-
)
611+
if type is not None:
612+
return schema.SchemaField(name, type, mode)
613+
else:
633614
return None
634615

635-
return augmented_schema
636-
# pytype: enable=attribute-error
637-
638616

639617
def dataframe_to_arrow(dataframe, bq_schema):
640618
"""Convert pandas dataframe to Arrow table, using BigQuery schema.

google/cloud/bigquery/job/base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,26 @@ def job_timeout_ms(self, value):
224224
else:
225225
self._properties.pop("jobTimeoutMs", None)
226226

227+
@property
228+
def reservation(self):
229+
"""str: Optional. The reservation that the job would use.
230+
231+
User can specify a reservation to execute the job. If reservation is
232+
not set, reservation is determined based on the rules defined by the
233+
reservation assignments. The expected format is
234+
projects/{project}/locations/{location}/reservations/{reservation}.
235+
236+
Raises:
237+
ValueError: If ``value`` is neither ``None`` nor a string.
238+
"""
239+
return self._properties.setdefault("reservation", None)
240+
241+
@reservation.setter
242+
def reservation(self, value):
243+
if value and not isinstance(value, str):
244+
raise ValueError("Reservation must be None or a string.")
245+
self._properties["reservation"] = value
246+
227247
@property
228248
def labels(self):
229249
"""Dict[str, str]: Labels for the job.
@@ -488,6 +508,18 @@ def location(self):
488508
"""str: Location where the job runs."""
489509
return _helpers._get_sub_prop(self._properties, ["jobReference", "location"])
490510

511+
@property
512+
def reservation_id(self):
513+
"""str: Name of the primary reservation assigned to this job.
514+
515+
Note that this could be different from the reservations reported in
516+
the reservation field if parent reservations were used to execute
517+
this job.
518+
"""
519+
return _helpers._get_sub_prop(
520+
self._properties, ["statistics", "reservation_id"]
521+
)
522+
491523
def _require_client(self, client):
492524
"""Check client or verify over-ride.
493525

google/cloud/bigquery/schema.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -284,15 +284,13 @@ def name(self):
284284
return self._properties.get("name", "")
285285

286286
@property
287-
def field_type(self):
287+
def field_type(self) -> str:
288288
"""str: The type of the field.
289289
290290
See:
291291
https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#TableFieldSchema.FIELDS.type
292292
"""
293293
type_ = self._properties.get("type")
294-
if type_ is None: # Shouldn't happen, but some unit tests do this.
295-
return None
296294
return cast(str, type_).upper()
297295

298296
@property
@@ -397,20 +395,16 @@ def _key(self):
397395
Returns:
398396
Tuple: The contents of this :class:`~google.cloud.bigquery.schema.SchemaField`.
399397
"""
400-
field_type = self.field_type.upper() if self.field_type is not None else None
401-
402-
# Type can temporarily be set to None if the code needs a SchemaField instance,
403-
# but has not determined the exact type of the field yet.
404-
if field_type is not None:
405-
if field_type == "STRING" or field_type == "BYTES":
406-
if self.max_length is not None:
407-
field_type = f"{field_type}({self.max_length})"
408-
elif field_type.endswith("NUMERIC"):
409-
if self.precision is not None:
410-
if self.scale is not None:
411-
field_type = f"{field_type}({self.precision}, {self.scale})"
412-
else:
413-
field_type = f"{field_type}({self.precision})"
398+
field_type = self.field_type
399+
if field_type == "STRING" or field_type == "BYTES":
400+
if self.max_length is not None:
401+
field_type = f"{field_type}({self.max_length})"
402+
elif field_type.endswith("NUMERIC"):
403+
if self.precision is not None:
404+
if self.scale is not None:
405+
field_type = f"{field_type}({self.precision}, {self.scale})"
406+
else:
407+
field_type = f"{field_type}({self.precision})"
414408

415409
policy_tags = (
416410
None if self.policy_tags is None else tuple(sorted(self.policy_tags.names))

tests/unit/job/test_base.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,16 @@ def test_state(self):
443443
status["state"] = state
444444
self.assertEqual(job.state, state)
445445

446+
def test_reservation_id(self):
447+
reservation_id = "RESERVATION-ID"
448+
client = _make_client(project=self.PROJECT)
449+
job = self._make_one(self.JOB_ID, client)
450+
self.assertIsNone(job.reservation_id)
451+
stats = job._properties["statistics"] = {}
452+
self.assertIsNone(job.reservation_id)
453+
stats["reservation_id"] = reservation_id
454+
self.assertEqual(job.reservation_id, reservation_id)
455+
446456
def _set_properties_job(self):
447457
client = _make_client(project=self.PROJECT)
448458
job = self._make_one(self.JOB_ID, client)
@@ -1188,31 +1198,37 @@ def test_fill_query_job_config_from_default(self):
11881198
job_config = QueryJobConfig()
11891199
job_config.dry_run = True
11901200
job_config.maximum_bytes_billed = 1000
1201+
job_config.reservation = "reservation_1"
11911202

11921203
default_job_config = QueryJobConfig()
11931204
default_job_config.use_query_cache = True
11941205
default_job_config.maximum_bytes_billed = 2000
1206+
default_job_config.reservation = "reservation_2"
11951207

11961208
final_job_config = job_config._fill_from_default(default_job_config)
11971209
self.assertTrue(final_job_config.dry_run)
11981210
self.assertTrue(final_job_config.use_query_cache)
11991211
self.assertEqual(final_job_config.maximum_bytes_billed, 1000)
1212+
self.assertEqual(final_job_config.reservation, "reservation_1")
12001213

12011214
def test_fill_load_job_from_default(self):
12021215
from google.cloud.bigquery import LoadJobConfig
12031216

12041217
job_config = LoadJobConfig()
12051218
job_config.create_session = True
12061219
job_config.encoding = "UTF-8"
1220+
job_config.reservation = "reservation_1"
12071221

12081222
default_job_config = LoadJobConfig()
12091223
default_job_config.ignore_unknown_values = True
12101224
default_job_config.encoding = "ISO-8859-1"
1225+
default_job_config.reservation = "reservation_2"
12111226

12121227
final_job_config = job_config._fill_from_default(default_job_config)
12131228
self.assertTrue(final_job_config.create_session)
12141229
self.assertTrue(final_job_config.ignore_unknown_values)
12151230
self.assertEqual(final_job_config.encoding, "UTF-8")
1231+
self.assertEqual(final_job_config.reservation, "reservation_1")
12161232

12171233
def test_fill_from_default_conflict(self):
12181234
from google.cloud.bigquery import QueryJobConfig
@@ -1232,10 +1248,12 @@ def test_fill_from_empty_default_conflict(self):
12321248
job_config = QueryJobConfig()
12331249
job_config.dry_run = True
12341250
job_config.maximum_bytes_billed = 1000
1251+
job_config.reservation = "reservation_1"
12351252

12361253
final_job_config = job_config._fill_from_default(default_job_config=None)
12371254
self.assertTrue(final_job_config.dry_run)
12381255
self.assertEqual(final_job_config.maximum_bytes_billed, 1000)
1256+
self.assertEqual(final_job_config.reservation, "reservation_1")
12391257

12401258
@mock.patch("google.cloud.bigquery._helpers._get_sub_prop")
12411259
def test__get_sub_prop_wo_default(self, _get_sub_prop):
@@ -1338,3 +1356,27 @@ def test_job_timeout_properties(self):
13381356
job_config.job_timeout_ms = None
13391357
assert job_config.job_timeout_ms is None
13401358
assert "jobTimeoutMs" not in job_config._properties
1359+
1360+
def test_reservation_miss(self):
1361+
job_config = self._make_one()
1362+
self.assertEqual(job_config.reservation, None)
1363+
1364+
def test_reservation_hit(self):
1365+
job_config = self._make_one()
1366+
job_config._properties["reservation"] = "foo"
1367+
self.assertEqual(job_config.reservation, "foo")
1368+
1369+
def test_reservation_update_in_place(self):
1370+
job_config = self._make_one()
1371+
job_config.reservation = "bar" # update in place
1372+
self.assertEqual(job_config.reservation, "bar")
1373+
1374+
def test_reservation_setter_invalid(self):
1375+
job_config = self._make_one()
1376+
with self.assertRaises(ValueError):
1377+
job_config.reservation = object()
1378+
1379+
def test_reservation_setter(self):
1380+
job_config = self._make_one()
1381+
job_config.reservation = "foo"
1382+
self.assertEqual(job_config._properties["reservation"], "foo")

0 commit comments

Comments
 (0)