1
- from datetime import datetime
1
+ from datetime import (
2
+ datetime ,
3
+ timezone ,
4
+ )
2
5
3
6
import numpy as np
4
7
import pytest
@@ -416,27 +419,13 @@ def test_non_str_names_w_duplicates():
416
419
pd .api .interchange .from_dataframe (dfi , allow_copy = False )
417
420
418
421
419
- def test_nullable_integers () -> None :
420
- # https://github.com/pandas-dev/pandas/issues/55069
421
- df = pd .DataFrame ({"a" : [1 ]}, dtype = "Int8" )
422
- expected = pd .DataFrame ({"a" : [1 ]}, dtype = "int8" )
423
- result = pd .api .interchange .from_dataframe (df .__dataframe__ ())
424
- tm .assert_frame_equal (result , expected )
425
-
426
-
427
- def test_nullable_integers_pyarrow () -> None :
428
- # https://github.com/pandas-dev/pandas/issues/55069
429
- df = pd .DataFrame ({"a" : [1 ]}, dtype = "Int8[pyarrow]" )
430
- expected = pd .DataFrame ({"a" : [1 ]}, dtype = "int8" )
431
- result = pd .api .interchange .from_dataframe (df .__dataframe__ ())
432
- tm .assert_frame_equal (result , expected )
433
-
434
-
435
422
@pytest .mark .parametrize (
436
423
("data" , "dtype" , "expected_dtype" ),
437
424
[
438
425
([1 , 2 , None ], "Int64" , "int64" ),
439
426
([1 , 2 , None ], "Int64[pyarrow]" , "int64" ),
427
+ ([1 , 2 , None ], "Int8" , "int8" ),
428
+ ([1 , 2 , None ], "Int8[pyarrow]" , "int8" ),
440
429
(
441
430
[1 , 2 , None ],
442
431
"UInt64" ,
@@ -451,15 +440,39 @@ def test_nullable_integers_pyarrow() -> None:
451
440
([1.0 , 2.25 , None ], "Float32[pyarrow]" , "float32" ),
452
441
([True , False , None ], "boolean" , "bool" ),
453
442
([True , False , None ], "boolean[pyarrow]" , "bool" ),
443
+ (["much ado" , "about" , None ], "string[pyarrow_numpy]" , "large_string" ),
444
+ (["much ado" , "about" , None ], "string[pyarrow]" , "large_string" ),
445
+ (
446
+ [datetime (2020 , 1 , 1 ), datetime (2020 , 1 , 2 ), None ],
447
+ "timestamp[ns][pyarrow]" ,
448
+ "timestamp[ns]" ,
449
+ ),
450
+ (
451
+ [datetime (2020 , 1 , 1 ), datetime (2020 , 1 , 2 ), None ],
452
+ "timestamp[us][pyarrow]" ,
453
+ "timestamp[us]" ,
454
+ ),
455
+ (
456
+ [
457
+ datetime (2020 , 1 , 1 , tzinfo = timezone .utc ),
458
+ datetime (2020 , 1 , 2 , tzinfo = timezone .utc ),
459
+ None ,
460
+ ],
461
+ "timestamp[us, Asia/Kathmandu][pyarrow]" ,
462
+ "timestamp[us, tz=Asia/Kathmandu]" ,
463
+ ),
454
464
],
455
465
)
456
466
def test_pandas_nullable_w_missing_values (
457
467
data : list , dtype : str , expected_dtype : str
458
468
) -> None :
459
469
# https://github.com/pandas-dev/pandas/issues/57643
460
- pytest .importorskip ("pyarrow" , "11.0.0" )
470
+ pa = pytest .importorskip ("pyarrow" , "11.0.0" )
461
471
import pyarrow .interchange as pai
462
472
473
+ if expected_dtype == "timestamp[us, tz=Asia/Kathmandu]" :
474
+ expected_dtype = pa .timestamp ("us" , "Asia/Kathmandu" )
475
+
463
476
df = pd .DataFrame ({"a" : data }, dtype = dtype )
464
477
result = pai .from_dataframe (df .__dataframe__ ())["a" ]
465
478
assert result .type == expected_dtype
0 commit comments