
Commit bb1b1e3

refactor: Add mappings from internal dtypes to bq types (#810)
1 parent: 6d947a2

File tree: 3 files changed (+61 −20 lines)

- bigframes/core/blocks.py
- bigframes/core/schema.py
- bigframes/dtypes.py

bigframes/core/blocks.py

Lines changed: 9 additions & 12 deletions
@@ -444,27 +444,24 @@ def _to_dataframe(self, result) -> pd.DataFrame:
         # Runs strict validations to ensure internal type predictions and ibis are completely in sync
         # Do not execute these validations outside of testing suite.
         if "PYTEST_CURRENT_TEST" in os.environ:
-            self._validate_result_schema(result_dataframe)
+            self._validate_result_schema(result.schema)
         return result_dataframe

-    def _validate_result_schema(self, result_df: pd.DataFrame):
+    def _validate_result_schema(
+        self, bq_result_schema: list[bigquery.schema.SchemaField]
+    ):
+        actual_schema = tuple(bq_result_schema)
         ibis_schema = self.expr._compiled_schema
         internal_schema = self.expr.node.schema
-        actual_schema = bf_schema.ArraySchema(
-            tuple(
-                bf_schema.SchemaItem(name, dtype)  # type: ignore
-                for name, dtype in result_df.dtypes.items()
-            )
-        )
         if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable:
             return
-        if internal_schema != actual_schema:
+        if internal_schema.to_bigquery() != actual_schema:
             raise ValueError(
-                f"This error should only occur while testing. BigFrames internal schema: {internal_schema} does not match actual schema: {actual_schema}"
+                f"This error should only occur while testing. BigFrames internal schema: {internal_schema.to_bigquery()} does not match actual schema: {actual_schema}"
             )
-        if ibis_schema != actual_schema:
+        if ibis_schema.to_bigquery() != actual_schema:
             raise ValueError(
-                f"This error should only occur while testing. Ibis schema: {ibis_schema} does not match actual schema: {actual_schema}"
+                f"This error should only occur while testing. Ibis schema: {ibis_schema.to_bigquery()} does not match actual schema: {actual_schema}"
             )

     def to_arrow(
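
The validation now compares tuples of `google.cloud.bigquery.SchemaField` on both sides: the internal and ibis schemas are projected through `to_bigquery()` and checked against the schema the query job actually returned, instead of rebuilding a BigFrames schema from pandas dtypes. A minimal sketch of that comparison pattern, using hypothetical field values (only the comparison shape comes from the diff):

```python
# Minimal sketch of the new validation pattern, with hypothetical schemas.
# Both sides are tuples of google.cloud.bigquery.SchemaField, so equality
# is a direct tuple comparison.
from google.cloud import bigquery

internal_schema = (
    bigquery.SchemaField("id", "INTEGER"),
    bigquery.SchemaField("score", "FLOAT"),
)
# In the real code this is tuple(result.schema) from the finished query job.
actual_schema = (
    bigquery.SchemaField("id", "INTEGER"),
    bigquery.SchemaField("score", "FLOAT"),
)

if internal_schema != actual_schema:
    raise ValueError(
        f"Internal schema {internal_schema} does not match actual schema {actual_schema}"
    )
```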

bigframes/core/schema.py

Lines changed: 6 additions & 0 deletions
@@ -58,6 +58,12 @@ def dtypes(self) -> typing.Tuple[bigframes.dtypes.Dtype, ...]:
     def _mapping(self) -> typing.Dict[ColumnIdentifierType, bigframes.dtypes.Dtype]:
         return {item.column: item.dtype for item in self.items}

+    def to_bigquery(self) -> typing.Tuple[google.cloud.bigquery.SchemaField, ...]:
+        return tuple(
+            bigframes.dtypes.convert_to_schema_field(item.column, item.dtype)
+            for item in self.items
+        )
+
     def drop(self, columns: typing.Iterable[str]) -> ArraySchema:
         return ArraySchema(
             tuple(item for item in self.items if item.column not in columns)
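
On the schema side, the new `ArraySchema.to_bigquery()` delegates per item to `bigframes.dtypes.convert_to_schema_field`. A hedged usage sketch; the `SchemaItem(column, dtype)` construction is taken from the old blocks.py code above, and the printed repr is approximate:

```python
# Sketch: projecting an in-memory BigFrames schema to BigQuery SchemaFields.
import bigframes.core.schema as bf_schema
import bigframes.dtypes

schema = bf_schema.ArraySchema(
    (
        bf_schema.SchemaItem("id", bigframes.dtypes.INT_DTYPE),
        bf_schema.SchemaItem("name", bigframes.dtypes.STRING_DTYPE),
    )
)
print(schema.to_bigquery())
# Expected (approximately):
# (SchemaField('id', 'INTEGER', 'NULLABLE', ...),
#  SchemaField('name', 'STRING', 'NULLABLE', ...))
```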

bigframes/dtypes.py

Lines changed: 46 additions & 8 deletions
@@ -70,7 +70,9 @@ class SimpleDtypeInfo:

     dtype: Dtype
     arrow_dtype: typing.Optional[pa.DataType]
-    type_kind: typing.Tuple[str, ...]  # Should all correspond to the same db type
+    type_kind: typing.Tuple[
+        str, ...
+    ]  # Should all correspond to the same db type. Put preferred canonical sql type name first
     logical_bytes: int = (
         8  # this is approximate only, some types are variably sized, also, compression
     )
@@ -84,20 +86,23 @@ class SimpleDtypeInfo:
     SimpleDtypeInfo(
         dtype=INT_DTYPE,
         arrow_dtype=pa.int64(),
-        type_kind=("INT64", "INTEGER"),
+        type_kind=("INTEGER", "INT64"),
         orderable=True,
         clusterable=True,
     ),
     SimpleDtypeInfo(
         dtype=FLOAT_DTYPE,
         arrow_dtype=pa.float64(),
-        type_kind=("FLOAT64", "FLOAT"),
+        type_kind=("FLOAT", "FLOAT64"),
         orderable=True,
     ),
     SimpleDtypeInfo(
         dtype=BOOL_DTYPE,
         arrow_dtype=pa.bool_(),
-        type_kind=("BOOL", "BOOLEAN"),
+        type_kind=(
+            "BOOLEAN",
+            "BOOL",
+        ),
         logical_bytes=1,
         orderable=True,
         clusterable=True,
@@ -143,15 +148,15 @@ class SimpleDtypeInfo:
     SimpleDtypeInfo(
         dtype=NUMERIC_DTYPE,
         arrow_dtype=pa.decimal128(38, 9),
-        type_kind=("NUMERIC",),
+        type_kind=("NUMERIC", "DECIMAL"),
         logical_bytes=16,
         orderable=True,
         clusterable=True,
     ),
     SimpleDtypeInfo(
         dtype=BIGNUMERIC_DTYPE,
         arrow_dtype=pa.decimal256(76, 38),
-        type_kind=("BIGNUMERIC",),
+        type_kind=("BIGNUMERIC", "BIGDECIMAL"),
         logical_bytes=32,
         orderable=True,
         clusterable=True,
@@ -417,6 +422,7 @@ def infer_literal_arrow_type(literal) -> typing.Optional[pa.DataType]:
     for mapping in SIMPLE_TYPES
     for type_kind in mapping.type_kind
 }
+_BIGFRAMES_TO_TK = {mapping.dtype: mapping.type_kind[0] for mapping in SIMPLE_TYPES}


 def convert_schema_field(
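
The `type_kind` reordering in the hunks above is what makes the new `_BIGFRAMES_TO_TK` mapping useful: it keeps only `type_kind[0]`, so the first entry becomes the canonical SQL name written into generated `SchemaField`s. A self-contained illustration of that dictionary comprehension (the dtype constants here are stand-ins assumed to match the bigframes ones):

```python
import pandas as pd

# Stand-ins; assumed to match bigframes.dtypes.INT_DTYPE / FLOAT_DTYPE.
INT_DTYPE = pd.Int64Dtype()
FLOAT_DTYPE = pd.Float64Dtype()

# type_kind tuples as reordered by this commit: canonical name first.
simple_types = {
    INT_DTYPE: ("INTEGER", "INT64"),
    FLOAT_DTYPE: ("FLOAT", "FLOAT64"),
}

# Same shape as _BIGFRAMES_TO_TK: dtype -> preferred canonical type kind.
bigframes_to_tk = {dtype: kinds[0] for dtype, kinds in simple_types.items()}
assert bigframes_to_tk[INT_DTYPE] == "INTEGER"  # was "INT64" before the reorder
```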
@@ -440,12 +446,44 @@ def convert_schema_field(
         if is_repeated:
             pa_type = pa.list_(bigframes_dtype_to_arrow_dtype(singular_type))
             return field.name, pd.ArrowDtype(pa_type)
-        else:
-            return field.name, singular_type
+        return field.name, singular_type
     else:
         raise ValueError(f"Cannot handle type: {field.field_type}")


+def convert_to_schema_field(
+    name: str,
+    bigframes_dtype: Dtype,
+) -> google.cloud.bigquery.SchemaField:
+    if bigframes_dtype in _BIGFRAMES_TO_TK:
+        return google.cloud.bigquery.SchemaField(
+            name, _BIGFRAMES_TO_TK[bigframes_dtype]
+        )
+    if isinstance(bigframes_dtype, pd.ArrowDtype):
+        if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
+            inner_type = arrow_dtype_to_bigframes_dtype(
+                bigframes_dtype.pyarrow_dtype.value_type
+            )
+            inner_field = convert_to_schema_field(name, inner_type)
+            return google.cloud.bigquery.SchemaField(
+                name, inner_field.field_type, mode="REPEATED", fields=inner_field.fields
+            )
+        if pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
+            inner_fields: list[pa.Field] = []
+            struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
+            for i in range(struct_type.num_fields):
+                field = struct_type.field(i)
+                inner_bf_type = arrow_dtype_to_bigframes_dtype(field.type)
+                inner_fields.append(convert_to_schema_field(field.name, inner_bf_type))
+
+            return google.cloud.bigquery.SchemaField(
+                name, "RECORD", fields=inner_fields
+            )
+    raise ValueError(
+        f"No arrow conversion for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
+    )
+
+
 def bf_type_from_type_kind(
     bq_schema: list[google.cloud.bigquery.SchemaField],
 ) -> typing.Dict[str, Dtype]:
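
Putting the new converter together: simple dtypes resolve through `_BIGFRAMES_TO_TK`, Arrow list dtypes become `REPEATED` fields of the element type, and Arrow structs become `RECORD` fields with recursively converted children. A usage sketch; the printed reprs are approximate, and it assumes `pd.Int64Dtype()` hashes equal to the module's `INT_DTYPE`:

```python
import pandas as pd
import pyarrow as pa
from bigframes import dtypes

# Simple dtype -> canonical type kind via _BIGFRAMES_TO_TK.
print(dtypes.convert_to_schema_field("id", pd.Int64Dtype()))
# SchemaField('id', 'INTEGER', 'NULLABLE', ...)

# Arrow list dtype -> REPEATED field of the element type.
print(dtypes.convert_to_schema_field("tags", pd.ArrowDtype(pa.list_(pa.string()))))
# SchemaField('tags', 'STRING', 'REPEATED', ...)

# Arrow struct dtype -> RECORD with recursively converted sub-fields.
point = pd.ArrowDtype(pa.struct([("x", pa.int64()), ("y", pa.float64())]))
print(dtypes.convert_to_schema_field("point", point))
# SchemaField('point', 'RECORD', 'NULLABLE',
#             fields=(SchemaField('x', 'INTEGER', ...), SchemaField('y', 'FLOAT', ...)))
```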
