Skip to content

Commit 81e8fb8

Browse files
authored
fix: use value equality to check types for unix epoch functions and timestamp diff (#1690)
1 parent 340b93d commit 81e8fb8

File tree

3 files changed

+75
-4
lines changed

3 files changed

+75
-4
lines changed

bigframes/operations/datetime_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class UnixSeconds(base_ops.UnaryOp):
8484
name: typing.ClassVar[str] = "unix_seconds"
8585

8686
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
87-
if input_types[0] is not dtypes.TIMESTAMP_DTYPE:
87+
if input_types[0] != dtypes.TIMESTAMP_DTYPE:
8888
raise TypeError("expected timestamp input")
8989
return dtypes.INT_DTYPE
9090

@@ -94,7 +94,7 @@ class UnixMillis(base_ops.UnaryOp):
9494
name: typing.ClassVar[str] = "unix_millis"
9595

9696
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
97-
if input_types[0] is not dtypes.TIMESTAMP_DTYPE:
97+
if input_types[0] != dtypes.TIMESTAMP_DTYPE:
9898
raise TypeError("expected timestamp input")
9999
return dtypes.INT_DTYPE
100100

@@ -104,7 +104,7 @@ class UnixMicros(base_ops.UnaryOp):
104104
name: typing.ClassVar[str] = "unix_micros"
105105

106106
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
107-
if input_types[0] is not dtypes.TIMESTAMP_DTYPE:
107+
if input_types[0] != dtypes.TIMESTAMP_DTYPE:
108108
raise TypeError("expected timestamp input")
109109
return dtypes.INT_DTYPE
110110

@@ -114,7 +114,7 @@ class TimestampDiff(base_ops.BinaryOp):
114114
name: typing.ClassVar[str] = "timestamp_diff"
115115

116116
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
117-
if input_types[0] is not input_types[1]:
117+
if input_types[0] != input_types[1]:
118118
raise TypeError(
119119
f"two inputs have different types. left: {input_types[0]}, right: {input_types[1]}"
120120
)

tests/system/small/bigquery/test_datetime.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,20 @@
1515
import typing
1616

1717
import pandas as pd
18+
import pyarrow as pa
1819
import pytest
1920

2021
from bigframes import bigquery
2122

23+
_TIMESTAMP_DTYPE = pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
24+
25+
26+
@pytest.fixture
27+
def int_series(session):
28+
pd_series = pd.Series([1, 2, 3, 4, 5])
29+
30+
return session.read_pandas(pd_series), pd_series
31+
2232

2333
def test_unix_seconds(scalars_dfs):
2434
bigframes_df, pandas_df = scalars_dfs
@@ -33,6 +43,19 @@ def test_unix_seconds(scalars_dfs):
3343
pd.testing.assert_series_equal(actual_res, expected_res)
3444

3545

46+
def test_unix_seconds_after_type_casting(int_series):
47+
bf_series, pd_series = int_series
48+
49+
actual_res = bigquery.unix_seconds(bf_series.astype(_TIMESTAMP_DTYPE)).to_pandas()
50+
51+
expected_res = (
52+
pd_series.astype(_TIMESTAMP_DTYPE)
53+
.apply(lambda ts: _to_unix_epoch(ts, "s"))
54+
.astype("Int64")
55+
)
56+
pd.testing.assert_series_equal(actual_res, expected_res, check_index_type=False)
57+
58+
3659
def test_unix_seconds_incorrect_input_type_raise_error(scalars_dfs):
3760
df, _ = scalars_dfs
3861

@@ -53,6 +76,19 @@ def test_unix_millis(scalars_dfs):
5376
pd.testing.assert_series_equal(actual_res, expected_res)
5477

5578

79+
def test_unix_millis_after_type_casting(int_series):
80+
bf_series, pd_series = int_series
81+
82+
actual_res = bigquery.unix_millis(bf_series.astype(_TIMESTAMP_DTYPE)).to_pandas()
83+
84+
expected_res = (
85+
pd_series.astype(_TIMESTAMP_DTYPE)
86+
.apply(lambda ts: _to_unix_epoch(ts, "ms"))
87+
.astype("Int64")
88+
)
89+
pd.testing.assert_series_equal(actual_res, expected_res, check_index_type=False)
90+
91+
5692
def test_unix_millis_incorrect_input_type_raise_error(scalars_dfs):
5793
df, _ = scalars_dfs
5894

@@ -73,6 +109,19 @@ def test_unix_micros(scalars_dfs):
73109
pd.testing.assert_series_equal(actual_res, expected_res)
74110

75111

112+
def test_unix_micros_after_type_casting(int_series):
113+
bf_series, pd_series = int_series
114+
115+
actual_res = bigquery.unix_micros(bf_series.astype(_TIMESTAMP_DTYPE)).to_pandas()
116+
117+
expected_res = (
118+
pd_series.astype(_TIMESTAMP_DTYPE)
119+
.apply(lambda ts: _to_unix_epoch(ts, "us"))
120+
.astype("Int64")
121+
)
122+
pd.testing.assert_series_equal(actual_res, expected_res, check_index_type=False)
123+
124+
76125
def test_unix_micros_incorrect_input_type_raise_error(scalars_dfs):
77126
df, _ = scalars_dfs
78127

tests/system/small/operations/test_timedeltas.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def temporal_dfs(session):
6060
],
6161
"float_col": [1.5, 2, -3],
6262
"int_col": [1, 2, -3],
63+
"positive_int_col": [1, 2, 3],
6364
}
6465
)
6566

@@ -607,3 +608,24 @@ def test_timedelta_agg__int_result(temporal_dfs, agg_func):
607608

608609
expected_result = agg_func(pd_df["timedelta_col_1"])
609610
assert actual_result == expected_result
611+
612+
613+
def test_timestamp_diff_after_type_casting(temporal_dfs):
614+
if version.Version(pd.__version__) <= version.Version("2.1.0"):
615+
pytest.skip(
616+
"Temporal type casting is not well-supported in older verions of Pandas."
617+
)
618+
619+
bf_df, pd_df = temporal_dfs
620+
dtype = pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
621+
622+
actual_result = (
623+
bf_df["timestamp_col"] - bf_df["positive_int_col"].astype(dtype)
624+
).to_pandas()
625+
626+
expected_result = pd_df["timestamp_col"] - pd_df["positive_int_col"].astype(
627+
"datetime64[us, UTC]"
628+
)
629+
pandas.testing.assert_series_equal(
630+
actual_result, expected_result, check_index_type=False, check_dtype=False
631+
)

0 commit comments

Comments
 (0)