|
10 | 10 | ) |
11 | 11 |
|
12 | 12 | import numpy as np |
13 | | -import pyarrow as pa |
14 | 13 |
|
15 | 14 | from pandas._libs import missing as libmissing |
16 | 15 |
|
@@ -77,6 +76,8 @@ def string( |
77 | 76 | if mode not in valid_modes: |
78 | 77 | raise ValueError(f"mode must be one of {valid_modes}, got {mode}") |
79 | 78 | if backend == "pyarrow": |
| 79 | + import pyarrow as pa |
| 80 | + |
80 | 81 | if mode == "string": |
81 | 82 | pa_type = pa.large_string() if large else pa.string() |
82 | 83 | else: # mode == "binary" |
@@ -128,6 +129,8 @@ def datetime( |
128 | 129 | return DatetimeTZDtype(unit=unit, tz=tz) |
129 | 130 | return np.dtype(f"datetime64[{unit}]") |
130 | 131 | else: # pyarrow |
| 132 | + import pyarrow as pa |
| 133 | + |
131 | 134 | return ArrowDtype(pa.timestamp(unit, tz=tz)) |
132 | 135 |
|
133 | 136 |
|
@@ -167,24 +170,25 @@ def integer( |
167 | 170 |
|
168 | 171 | if backend == "numpy": |
169 | 172 | return np.dtype(f"int{bits}") |
170 | | - |
171 | | - if backend == "pandas": |
| 173 | + elif backend == "pandas": |
172 | 174 | if bits == 8: |
173 | 175 | return Int8Dtype() |
174 | 176 | elif bits == 16: |
175 | 177 | return Int16Dtype() |
176 | 178 | elif bits == 32: |
177 | 179 | return Int32Dtype() |
178 | | - elif bits == 64: |
| 180 | + else: # bits == 64 |
179 | 181 | return Int64Dtype() |
180 | 182 | elif backend == "pyarrow": |
| 183 | + import pyarrow as pa |
| 184 | + |
181 | 185 | if bits == 8: |
182 | 186 | return ArrowDtype(pa.int8()) |
183 | 187 | elif bits == 16: |
184 | 188 | return ArrowDtype(pa.int16()) |
185 | 189 | elif bits == 32: |
186 | 190 | return ArrowDtype(pa.int32()) |
187 | | - elif bits == 64: |
| 191 | + else: # bits == 64 |
188 | 192 | return ArrowDtype(pa.int64()) |
189 | 193 | else: |
190 | 194 | raise ValueError(f"Unsupported backend: {backend!r}") |
@@ -224,16 +228,17 @@ def floating( |
224 | 228 |
|
225 | 229 | if backend == "numpy": |
226 | 230 | return np.dtype(f"float{bits}") |
227 | | - |
228 | | - if backend == "pandas": |
| 231 | + elif backend == "pandas": |
229 | 232 | if bits == 32: |
230 | 233 | return Float32Dtype() |
231 | | - elif bits == 64: |
| 234 | + else: # bits == 64 |
232 | 235 | return Float64Dtype() |
233 | 236 | elif backend == "pyarrow": |
| 237 | + import pyarrow as pa |
| 238 | + |
234 | 239 | if bits == 32: |
235 | 240 | return ArrowDtype(pa.float32()) |
236 | | - elif bits == 64: |
| 241 | + else: # bits == 64 |
237 | 242 | return ArrowDtype(pa.float64()) |
238 | 243 | else: |
239 | 244 | raise ValueError(f"Unsupported backend: {backend!r}") |
@@ -270,6 +275,8 @@ def decimal( |
270 | 275 | decimal256[40, 5][pyarrow] |
271 | 276 | """ |
272 | 277 | if backend == "pyarrow": |
| 278 | + import pyarrow as pa |
| 279 | + |
273 | 280 | if precision <= 38: |
274 | 281 | return ArrowDtype(pa.decimal128(precision, scale)) |
275 | 282 | return ArrowDtype(pa.decimal256(precision, scale)) |
@@ -302,6 +309,8 @@ def boolean( |
302 | 309 | if backend == "numpy": |
303 | 310 | return BooleanDtype() |
304 | 311 | else: # pyarrow |
| 312 | + import pyarrow as pa |
| 313 | + |
305 | 314 | return ArrowDtype(pa.bool_()) |
306 | 315 |
|
307 | 316 |
|
@@ -344,6 +353,8 @@ def list( |
344 | 353 | if backend == "numpy": |
345 | 354 | return np.dtype("object") |
346 | 355 | else: # pyarrow |
| 356 | + import pyarrow as pa |
| 357 | + |
347 | 358 | if value_type is None: |
348 | 359 | value_type = pa.int64() |
349 | 360 | pa_type = pa.large_list(value_type) if large else pa.list_(value_type) |
@@ -396,6 +407,8 @@ def categorical( |
396 | 407 | if backend == "numpy": |
397 | 408 | return CategoricalDtype(categories=categories, ordered=ordered) |
398 | 409 | else: # pyarrow |
| 410 | + import pyarrow as pa |
| 411 | + |
399 | 412 | index_type = pa.int32() if index_type is None else index_type |
400 | 413 | value_type = pa.string() if value_type is None else value_type |
401 | 414 | return ArrowDtype(pa.dictionary(index_type, value_type)) |
@@ -437,6 +450,8 @@ def interval( |
437 | 450 | if backend == "numpy": |
438 | 451 | return IntervalDtype(subtype=subtype, closed=closed) |
439 | 452 | else: # pyarrow |
| 453 | + import pyarrow as pa |
| 454 | + |
440 | 455 | if subtype is not None: |
441 | 456 | return ArrowDtype( |
442 | 457 | pa.struct( |
@@ -491,6 +506,8 @@ def period( |
491 | 506 | if backend == "numpy": |
492 | 507 | return PeriodDtype(freq=freq) |
493 | 508 | else: # pyarrow |
| 509 | + import pyarrow as pa |
| 510 | + |
494 | 511 | return ArrowDtype(pa.month_day_nano_interval()) |
495 | 512 |
|
496 | 513 |
|
@@ -590,6 +607,8 @@ def date( |
590 | 607 |
|
591 | 608 | if backend != "pyarrow": |
592 | 609 | raise ValueError("Date types are only supported with PyArrow backend.") |
| 610 | + import pyarrow as pa |
| 611 | + |
593 | 612 | return ArrowDtype(pa.date32() if unit == "day" else pa.date64()) |
594 | 613 |
|
595 | 614 |
|
@@ -629,6 +648,8 @@ def duration( |
629 | 648 | if backend == "numpy": |
630 | 649 | return np.dtype(f"timedelta64[{unit}]") |
631 | 650 | else: # pyarrow |
| 651 | + import pyarrow as pa |
| 652 | + |
632 | 653 | return ArrowDtype(pa.duration(unit)) |
633 | 654 |
|
634 | 655 |
|
@@ -677,6 +698,8 @@ def map( |
677 | 698 | """ |
678 | 699 | if backend != "pyarrow": |
679 | 700 | raise ValueError("Map types are only supported with PyArrow backend.") |
| 701 | + import pyarrow as pa |
| 702 | + |
680 | 703 | return ArrowDtype(pa.map_(index_type, value_type)) |
681 | 704 |
|
682 | 705 |
|
@@ -724,14 +747,10 @@ def struct( |
724 | 747 | 1 (2, Bob) |
725 | 748 | dtype: struct<id: int32, name: string>[pyarrow] |
726 | 749 | """ |
727 | | - if backend != "pyarrow": |
728 | | - raise ValueError("Struct types are only supported with PyArrow backend.") |
729 | | - # Validate that fields is a list of (str, type) tuples |
730 | | - for field in fields: |
731 | | - if ( |
732 | | - not isinstance(field, tuple) |
733 | | - or len(field) != 2 |
734 | | - or not isinstance(field[0], str) |
735 | | - ): |
736 | | - raise ValueError("Each field must be a tuple of (str, type), got {field}") |
737 | | - return ArrowDtype(pa.struct(fields)) |
| 750 | + if backend == "pyarrow": |
| 751 | + import pyarrow as pa |
| 752 | + |
| 753 | + pa_fields = [(name, getattr(typ, "pyarrow_dtype", typ)) for name, typ in fields] |
| 754 | + return ArrowDtype(pa.struct(pa_fields)) |
| 755 | + else: |
| 756 | + raise ValueError(f"Unsupported backend: {backend!r}") |
0 commit comments