
Commit b4d17ff

feat: add astype(type, errors='null') to cast safely (#1122)
1 parent 9752da1 commit b4d17ff
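
The new keyword is threaded through Series.astype, DataFrame.astype, and Index.astype. A minimal usage sketch mirroring the test_astype_safe system test added below (assumes a configured BigQuery session, since the cast runs as BigQuery SQL):

import pandas as pd
import bigframes.pandas as bpd

data = pd.Series(["hello", "world", "3.11", "4000"])

# errors="raise" (the default) keeps the old behavior: the cast fails on
# values such as "hello" that cannot be converted to Float64.
# bpd.read_pandas(data).astype("Float64").to_pandas()  # would raise

# errors="null": values that fail the cast come back as NULL (<NA>) instead.
lenient = bpd.read_pandas(data).astype("Float64", errors="null")
print(lenient.to_pandas())  # <NA>, <NA>, 3.11, 4000.0 with dtype Float64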

11 files changed, +96 −17 lines changed

bigframes/core/compile/ibis_types.py

Lines changed: 10 additions & 4 deletions
@@ -89,7 +89,7 @@
 
 
 def cast_ibis_value(
-    value: ibis_types.Value, to_type: ibis_dtypes.DataType
+    value: ibis_types.Value, to_type: ibis_dtypes.DataType, safe: bool = False
 ) -> ibis_types.Value:
     """Perform compatible type casts of ibis values
 
@@ -176,7 +176,7 @@ def cast_ibis_value(
     value = ibis_value_to_canonical_type(value)
     if value.type() in good_casts:
         if to_type in good_casts[value.type()]:
-            return value.cast(to_type)
+            return value.try_cast(to_type) if safe else value.cast(to_type)
     else:
         # this should never happen
         raise TypeError(
@@ -188,10 +188,16 @@ def cast_ibis_value(
     # BigQuery casts bools to lower case strings. Capitalize the result to match Pandas
     # TODO(bmil): remove this workaround after fixing Ibis
     if value.type() == ibis_dtypes.bool and to_type == ibis_dtypes.string:
-        return cast(ibis_types.StringValue, value.cast(to_type)).capitalize()
+        if safe:
+            return cast(ibis_types.StringValue, value.try_cast(to_type)).capitalize()
+        else:
+            return cast(ibis_types.StringValue, value.cast(to_type)).capitalize()
 
     if value.type() == ibis_dtypes.bool and to_type == ibis_dtypes.float64:
-        return value.cast(ibis_dtypes.int64).cast(ibis_dtypes.float64)
+        if safe:
+            return value.try_cast(ibis_dtypes.int64).try_cast(ibis_dtypes.float64)
+        else:
+            return value.cast(ibis_dtypes.int64).cast(ibis_dtypes.float64)
 
     if value.type() == ibis_dtypes.float64 and to_type == ibis_dtypes.bool:
         return value != ibis_types.literal(0)
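
At the compiler level the change amounts to choosing between Ibis cast and try_cast. A rough standalone illustration of that distinction, not taken from this commit (assumes an Ibis release that provides Value.try_cast, which on the BigQuery backend should correspond to SAFE_CAST and yield NULL instead of an error):

import ibis

t = ibis.memtable({"s": ["hello", "3.11"]})

strict = t.s.cast("float64")       # plain CAST: the query errors on "hello"
lenient = t.s.try_cast("float64")  # safe cast: NULL where the conversion fails

print(ibis.to_sql(lenient, dialect="bigquery"))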

bigframes/core/compile/scalar_op_compiler.py

Lines changed: 21 additions & 7 deletions
@@ -947,7 +947,9 @@ def struct_field_op_impl(x: ibis_types.Value, op: ops.StructFieldOp):
     return result.cast(result.type()(nullable=True)).name(name)
 
 
-def numeric_to_datetime(x: ibis_types.Value, unit: str) -> ibis_types.TimestampValue:
+def numeric_to_datetime(
+    x: ibis_types.Value, unit: str, safe: bool = False
+) -> ibis_types.TimestampValue:
     if not isinstance(x, ibis_types.IntegerValue) and not isinstance(
         x, ibis_types.FloatingValue
     ):
@@ -956,7 +958,11 @@ def numeric_to_datetime(x: ibis_types.Value, unit: str) -> ibis_types.TimestampValue:
     if unit not in UNIT_TO_US_CONVERSION_FACTORS:
         raise ValueError(f"Cannot convert input with unit '{unit}'.")
     x_converted = x * UNIT_TO_US_CONVERSION_FACTORS[unit]
-    x_converted = x_converted.cast(ibis_dtypes.int64)
+    x_converted = (
+        x_converted.try_cast(ibis_dtypes.int64)
+        if safe
+        else x_converted.cast(ibis_dtypes.int64)
+    )
 
     # Note: Due to an issue where casting directly to a timestamp
     # without a timezone does not work, we first cast to UTC. This
@@ -978,8 +984,11 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp):
 
     # When casting DATETIME column into INT column, we need to convert the column into TIMESTAMP first.
     if to_type == ibis_dtypes.int64 and x.type() == ibis_dtypes.timestamp:
-        x_converted = x.cast(ibis_dtypes.Timestamp(timezone="UTC"))
-        return bigframes.core.compile.ibis_types.cast_ibis_value(x_converted, to_type)
+        utc_time_type = ibis_dtypes.Timestamp(timezone="UTC")
+        x_converted = x.try_cast(utc_time_type) if op.safe else x.cast(utc_time_type)
+        return bigframes.core.compile.ibis_types.cast_ibis_value(
+            x_converted, to_type, safe=op.safe
+        )
 
     if to_type == ibis_dtypes.int64 and x.type() == ibis_dtypes.time:
         # The conversion unit is set to "us" (microseconds) for consistency
@@ -991,15 +1000,20 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp):
         # with pandas converting int64[pyarrow] to timestamp[us][pyarrow],
         # timestamp[us, tz=UTC][pyarrow], and time64[us][pyarrow].
         unit = "us"
-        x_converted = numeric_to_datetime(x, unit)
+        x_converted = numeric_to_datetime(x, unit, safe=op.safe)
         if to_type == ibis_dtypes.timestamp:
-            return x_converted.cast(ibis_dtypes.Timestamp())
+            return (
+                x_converted.try_cast(ibis_dtypes.Timestamp())
+                if op.safe
+                else x_converted.cast(ibis_dtypes.Timestamp())
+            )
         elif to_type == ibis_dtypes.Timestamp(timezone="UTC"):
             return x_converted
         elif to_type == ibis_dtypes.time:
             return x_converted.time()
 
-    return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type)
+    # TODO: either inline this function, or push rest of this op into the function
+    return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type, safe=op.safe)
 
 
 @scalar_op_compiler.register_unary_op(ops.IsInOp, pass_op=True)

bigframes/core/indexes/base.py

Lines changed: 8 additions & 2 deletions
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 import typing
-from typing import Hashable, Optional, Sequence, Union
+from typing import Hashable, Literal, Optional, Sequence, Union
 
 import bigframes_vendored.constants as constants
 import bigframes_vendored.pandas.core.indexes.base as vendored_pandas_index
@@ -324,11 +324,17 @@ def sort_values(self, *, ascending: bool = True, na_position: str = "last"):
     def astype(
         self,
         dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
+        *,
+        errors: Literal["raise", "null"] = "raise",
     ) -> Index:
+        if errors not in ["raise", "null"]:
+            raise ValueError("Arg 'error' must be one of 'raise' or 'null'")
        if self.nlevels > 1:
             raise TypeError("Multiindex does not support 'astype'")
         return self._apply_unary_expr(
-            ops.AsTypeOp(to_type=dtype).as_expr(ex.free_var("arg"))
+            ops.AsTypeOp(to_type=dtype, safe=(errors == "null")).as_expr(
+                ex.free_var("arg")
+            )
         )
 
     def all(self) -> bool:

bigframes/dataframe.py

Lines changed: 7 additions & 1 deletion
@@ -365,8 +365,14 @@ def __iter__(self):
     def astype(
         self,
         dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
+        *,
+        errors: Literal["raise", "null"] = "raise",
     ) -> DataFrame:
-        return self._apply_unary_op(ops.AsTypeOp(to_type=dtype))
+        if errors not in ["raise", "null"]:
+            raise ValueError("Arg 'error' must be one of 'raise' or 'null'")
+        return self._apply_unary_op(
+            ops.AsTypeOp(to_type=dtype, safe=(errors == "null"))
+        )
 
     def _to_sql_query(
         self, include_index: bool, enable_cache: bool = True

bigframes/operations/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -494,6 +494,7 @@ class AsTypeOp(UnaryOp):
     name: typing.ClassVar[str] = "astype"
     # TODO: Convert strings to dtype earlier
     to_type: dtypes.DtypeString | dtypes.Dtype
+    safe: bool = False
 
     def output_type(self, *input_types):
         # TODO: We should do this conversion earlier

bigframes/series.py

Lines changed: 7 additions & 1 deletion
@@ -352,8 +352,14 @@ def __repr__(self) -> str:
     def astype(
         self,
         dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
+        *,
+        errors: Literal["raise", "null"] = "raise",
     ) -> Series:
-        return self._apply_unary_op(bigframes.operations.AsTypeOp(to_type=dtype))
+        if errors not in ["raise", "null"]:
+            raise ValueError("Arg 'error' must be one of 'raise' or 'null'")
+        return self._apply_unary_op(
+            bigframes.operations.AsTypeOp(to_type=dtype, safe=(errors == "null"))
+        )
 
     def to_pandas(
         self,

tests/system/small/test_dataframe.py

Lines changed: 6 additions & 0 deletions
@@ -3687,6 +3687,12 @@ def test_df_add_suffix(scalars_df_index, scalars_pandas_df_index, axis):
     )
 
 
+def test_df_astype_error_error(session):
+    input = pd.DataFrame(["hello", "world", "3.11", "4000"])
+    with pytest.raises(ValueError):
+        session.read_pandas(input).astype("Float64", errors="bad_value")
+
+
 def test_df_columns_filter_items(scalars_df_index, scalars_pandas_df_index):
     if pd.__version__.startswith("2.0") or pd.__version__.startswith("1."):
         pytest.skip("pandas filter items behavior different pre-2.1")

tests/system/small/test_index.py

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ def test_index_astype(scalars_df_index, scalars_pandas_df_index):
     pd.testing.assert_index_equal(bf_result, pd_result)
 
 
+def test_index_astype_error_error(session):
+    input = pd.Index(["hello", "world", "3.11", "4000"])
+    with pytest.raises(ValueError):
+        session.read_pandas(input).astype("Float64", errors="bad_value")
+
+
 def test_index_any(scalars_df_index, scalars_pandas_df_index):
     bf_result = scalars_df_index.set_index("int64_col").index.any()
     pd_result = scalars_pandas_df_index.set_index("int64_col").index.any()

tests/system/small/test_series.py

Lines changed: 21 additions & 2 deletions
@@ -3092,6 +3092,7 @@ def foo(x):
     assert_series_equal(bf_result, pd_result, check_dtype=False)
 
 
+@pytest.mark.parametrize("errors", ["raise", "null"])
 @pytest.mark.parametrize(
     ("column", "to_type"),
     [
@@ -3107,6 +3108,7 @@ def foo(x):
         ("int64_col", "time64[us][pyarrow]"),
         ("bool_col", "Int64"),
         ("bool_col", "string[pyarrow]"),
+        ("bool_col", "Float64"),
         ("string_col", "binary[pyarrow]"),
         ("bytes_col", "string[pyarrow]"),
         # pandas actually doesn't let folks convert to/from naive timestamp and
@@ -3142,12 +3144,29 @@ def foo(x):
     ],
 )
 @skip_legacy_pandas
-def test_astype(scalars_df_index, scalars_pandas_df_index, column, to_type):
-    bf_result = scalars_df_index[column].astype(to_type).to_pandas()
+def test_astype(scalars_df_index, scalars_pandas_df_index, column, to_type, errors):
+    bf_result = scalars_df_index[column].astype(to_type, errors=errors).to_pandas()
     pd_result = scalars_pandas_df_index[column].astype(to_type)
     pd.testing.assert_series_equal(bf_result, pd_result)
 
 
+def test_astype_safe(session):
+    input = pd.Series(["hello", "world", "3.11", "4000"])
+    exepcted = pd.Series(
+        [None, None, 3.11, 4000],
+        dtype="Float64",
+        index=pd.Index([0, 1, 2, 3], dtype="Int64"),
+    )
+    result = session.read_pandas(input).astype("Float64", errors="null").to_pandas()
+    pd.testing.assert_series_equal(result, exepcted)
+
+
+def test_series_astype_error_error(session):
+    input = pd.Series(["hello", "world", "3.11", "4000"])
+    with pytest.raises(ValueError):
+        session.read_pandas(input).astype("Float64", errors="bad_value")
+
+
 @skip_legacy_pandas
 def test_astype_numeric_to_int(scalars_df_index, scalars_pandas_df_index):
     column = "numeric_col"

third_party/bigframes_vendored/pandas/core/generic.py

Lines changed: 4 additions & 0 deletions
@@ -180,6 +180,10 @@ def astype(self, dtype):
                 ``pd.ArrowDtype(pa.time64("us"))``,
                 ``pd.ArrowDtype(pa.timestamp("us"))``,
                 ``pd.ArrowDtype(pa.timestamp("us", tz="UTC"))``.
+            errors ({'raise', 'null'}, default 'raise'):
+                Control raising of exceptions on invalid data for provided dtype.
+                If 'raise', allow exceptions to be raised if any value fails cast
+                If 'null', will assign null value if value fails cast
 
         Returns:
             bigframes.pandas.DataFrame:
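
A short sketch of how the documented values behave on a DataFrame (illustrative data; assumes a configured BigQuery session). Only 'raise' and 'null' are accepted; anything else is rejected with a ValueError before any SQL runs, as the new tests above exercise:

import pandas as pd
import bigframes.pandas as bpd

df = bpd.read_pandas(pd.DataFrame({"col": ["hello", "world", "3.11", "4000"]}))

df.astype("Float64", errors="null")   # failed casts become NULL
df.astype("Float64")                  # default errors="raise": the cast fails when computed
# df.astype("Float64", errors="ignore")  # raises ValueError immediately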
