Skip to content

Commit e4807e2

Browse files
committed
Add tests for string arrays
1 parent bf8c9a5 commit e4807e2

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

pygmt/clib/conversion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ def _to_ndarray(array: Any) -> np.ndarray:
158158
"""
159159
# A dictionary mapping unsupported dtypes to the expected numpy dtype.
160160
dtypes: dict[str, type] = {
161+
# "string" for "string[python]", "string[pyarrow]", "string[pyarrow_numpy]", and
162+
# pa.string()
163+
"string": np.str_,
161164
"date32[day][pyarrow]": np.datetime64,
162165
"date64[ms][pyarrow]": np.datetime64,
163166
}
@@ -184,7 +187,7 @@ def _to_ndarray(array: Any) -> np.ndarray:
184187
if hasattr(array, "isna") and array.isna().any():
185188
array = array.astype(np.float64)
186189

187-
vec_dtype = str(getattr(array, "dtype", ""))
190+
vec_dtype = str(getattr(array, "dtype", getattr(array, "type", "")))
188191
array = np.ascontiguousarray(array, dtype=dtypes.get(vec_dtype))
189192
return array
190193

pygmt/tests/test_clib_to_ndarray.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,17 @@ def test_to_ndarray_numpy_ndarray_numpy_numeric(dtype):
7272
npt.assert_array_equal(result, array)
7373

7474

75+
@pytest.mark.parametrize("dtype", [None, np.str_])
76+
def test_to_ndarray_numpy_ndarray_numpy_string(dtype):
77+
"""
78+
Test the _to_ndarray function with 1-D NumPy arrays of strings.
79+
"""
80+
array = np.array(["a", "b", "c"], dtype=dtype)
81+
result = _to_ndarray(array)
82+
_check_result(result)
83+
npt.assert_array_equal(result, array)
84+
85+
7586
@pytest.mark.parametrize(
7687
"dtype",
7788
[
@@ -146,6 +157,26 @@ def test_to_ndarray_pandas_series_numeric_with_na(dtype):
146157
npt.assert_array_equal(result, np.array([1, np.nan, 3], dtype=np.float64))
147158

148159

160+
@pytest.mark.parametrize(
161+
"dtype",
162+
[
163+
# None,
164+
# np.str_,
165+
"string[python]",
166+
pytest.param("string[pyarrow]", marks=skip_if_no(package="pyarrow")),
167+
pytest.param("string[pyarrow_numpy]", marks=skip_if_no(package="pyarrow")),
168+
],
169+
)
170+
def test_to_ndarray_pandas_series_string(dtype):
171+
"""
172+
Test the _to_ndarray function with pandas Series with string dtype.
173+
"""
174+
series = pd.Series(["a", "bcd", "12345"], dtype=dtype)
175+
result = _to_ndarray(series)
176+
_check_result(result)
177+
npt.assert_array_equal(result, series)
178+
179+
149180
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
150181
@pytest.mark.parametrize(
151182
"dtype",
@@ -184,3 +215,14 @@ def test_to_ndarray_pyarrow_array_float16():
184215
result = _to_ndarray(array)
185216
_check_result(result)
186217
npt.assert_array_equal(result, array)
218+
219+
220+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
221+
def test_to_ndarray_pyarrow_array_string():
222+
"""
223+
Test the _to_ndarray function with pyarrow string array.
224+
"""
225+
array = pa.array(["a", "bcd", "12345"], type=pa.string())
226+
result = _to_ndarray(array)
227+
_check_result(result)
228+
npt.assert_array_equal(result, array)

0 commit comments

Comments
 (0)