|
19 | 19 | import threading |
20 | 20 | import unittest |
21 | 21 | import unittest.mock as mock |
22 | | -from datetime import date, datetime |
| 22 | +from datetime import date, datetime, timedelta |
23 | 23 | from pathlib import Path |
24 | 24 | from test import client_context |
25 | 25 | from test.utils import AllowListEventListener, NullsTestMixin |
@@ -1095,22 +1095,22 @@ def alltypes_sample(self, size=10000, seed=0, categorical=False): |
1095 | 1095 |
|
1096 | 1096 | np.random.seed(seed) |
1097 | 1097 | arrays = { |
1098 | | - "uint8": pa.array(np.arange(size, dtype=np.uint8), type=pa.int32()), |
1099 | | - "uint16": pa.array(np.arange(size, dtype=np.uint16), type=pa.int32()), |
1100 | | - "uint32": pa.array(np.arange(size, dtype=np.uint32), type=pa.int64()), |
1101 | | - "uint64": pa.array(np.arange(size, dtype=np.uint64), type=pa.int64()), |
1102 | | - "int8": pa.array(np.arange(size, dtype=np.int8), type=pa.int32()), |
1103 | | - "int16": pa.array(np.arange(size, dtype=np.int16), type=pa.int32()), |
| 1098 | + "uint8": np.arange(size, dtype=np.uint8), |
| 1099 | + "uint16": np.arange(size, dtype=np.uint16), |
| 1100 | + "uint32": np.arange(size, dtype=np.uint32), |
| 1101 | + "uint64": np.arange(size, dtype=np.uint64), |
| 1102 | + "int8": np.arange(size, dtype=np.int8), |
| 1103 | + "int16": np.arange(size, dtype=np.int16), |
1104 | 1104 | "int32": np.arange(size, dtype=np.int32), |
1105 | 1105 | "int64": np.arange(size, dtype=np.int64), |
1106 | | - "float16": pa.array(np.arange(size, dtype=np.float16), type=pa.float64()), |
1107 | | - "float32": pa.array(np.arange(size, dtype=np.float32), type=pa.float64()), |
1108 | | - "float64": pa.array(np.arange(size, dtype=np.float64), type=pa.float64()), |
| 1106 | + "float16": np.arange(size, dtype=np.float16), |
| 1107 | + "float32": np.arange(size, dtype=np.float32), |
| 1108 | + "float64": np.arange(size, dtype=np.float64), |
1109 | 1109 | "bool": np.random.randn(size) > 0, |
1110 | 1110 | "datetime_ms": np.arange("2016-01-01T00:00:00.001", size, dtype="datetime64[ms]"), |
1111 | 1111 | "datetime_us": np.arange("2016-01-01T00:00:00.000001", size, dtype="datetime64[us]"), |
1112 | 1112 | "datetime_ns": np.arange("2016-01-01T00:00:00.000000001", size, dtype="datetime64[ns]"), |
1113 | | - "timedelta": pa.array(np.arange(size, dtype="timedelta64[s]"), type=pa.int64()), |
| 1113 | + "timedelta": np.arange(size, dtype="timedelta64[s]"), |
1114 | 1114 | "str": pd.Series([str(x) for x in range(size)]), |
1115 | 1115 | "empty_str": [""] * size, |
1116 | 1116 | "str_with_nulls": [None] + [str(x) for x in range(size - 2)] + [None], |
@@ -1181,6 +1181,10 @@ def compare_arrow_mongodb_data(self, arrow_table, mongo_data): |
1181 | 1181 | assert ( |
1182 | 1182 | arrow_value == mongo_value |
1183 | 1183 | ), f"List mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}." |
| 1184 | + elif isinstance(arrow_value, timedelta): |
| 1185 | + assert ( |
| 1186 | + arrow_value == timedelta(seconds=mongo_value) |
| 1187 | + ), f"Timedelta mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}." |
1184 | 1188 | else: |
1185 | 1189 | assert ( |
1186 | 1190 | arrow_value == mongo_value |
|
0 commit comments