|
9 | 9 | from pandas.compat import HAS_PYARROW |
10 | 10 |
|
11 | 11 | from pandas import ( |
| 12 | + ArrowDtype, |
12 | 13 | DataFrame, |
13 | 14 | Index, |
14 | 15 | Series, |
@@ -539,18 +540,38 @@ def test_int_dtype_different_index_not_bool(self): |
539 | 540 | result = ser1 ^ ser2 |
540 | 541 | tm.assert_series_equal(result, expected) |
541 | 542 |
|
| 543 | + # TODO: this belongs in comparison tests |
542 | 544 | def test_pyarrow_numpy_string_invalid(self): |
543 | 545 | # GH#56008 |
544 | | - pytest.importorskip("pyarrow") |
| 546 | + pa = pytest.importorskip("pyarrow") |
545 | 547 | ser = Series([False, True]) |
546 | 548 | ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]") |
547 | 549 | result = ser == ser2 |
548 | | - expected = Series(False, index=ser.index) |
549 | | - tm.assert_series_equal(result, expected) |
| 550 | + expected_eq = Series(False, index=ser.index) |
| 551 | + tm.assert_series_equal(result, expected_eq) |
550 | 552 |
|
551 | 553 | result = ser != ser2 |
552 | | - expected = Series(True, index=ser.index) |
553 | | - tm.assert_series_equal(result, expected) |
| 554 | + expected_ne = Series(True, index=ser.index) |
| 555 | + tm.assert_series_equal(result, expected_ne) |
554 | 556 |
|
555 | 557 | with pytest.raises(TypeError, match="Invalid comparison"): |
556 | 558 | ser > ser2 |
| 559 | + |
| 560 | + # GH#59505 |
| 561 | + ser3 = ser2.astype("string[pyarrow]") |
| 562 | + result3_eq = ser3 == ser |
| 563 | + tm.assert_series_equal(result3_eq, expected_eq.astype("bool[pyarrow]")) |
| 564 | + result3_ne = ser3 != ser |
| 565 | + tm.assert_series_equal(result3_ne, expected_ne.astype("bool[pyarrow]")) |
| 566 | + |
| 567 | + with pytest.raises(TypeError, match="Invalid comparison"): |
| 568 | + ser > ser3 |
| 569 | + |
| 570 | + ser4 = ser2.astype(ArrowDtype(pa.string())) |
| 571 | + result4_eq = ser4 == ser |
| 572 | + tm.assert_series_equal(result4_eq, expected_eq.astype("bool[pyarrow]")) |
| 573 | + result4_ne = ser4 != ser |
| 574 | + tm.assert_series_equal(result4_ne, expected_ne.astype("bool[pyarrow]")) |
| 575 | + |
| 576 | + with pytest.raises(TypeError, match="Invalid comparison"): |
| 577 | + ser > ser4 |
0 commit comments