Commit 1746c04

Session.virtualfile_to_dataset: Add 'header' parameter to parse column names from table header (#3117)
1 parent fd286fb commit 1746c04

File tree (3 files changed, +84 −8 lines changed):

pygmt/clib/session.py
pygmt/datatypes/dataset.py
pygmt/tests/test_datatypes_dataset.py

pygmt/clib/session.py

Lines changed: 6 additions & 1 deletion
@@ -1810,6 +1810,7 @@ def virtualfile_to_dataset(
         self,
         vfname: str,
         output_type: Literal["pandas", "numpy", "file", "strings"] = "pandas",
+        header: int | None = None,
         column_names: list[str] | None = None,
         dtype: type | dict[str, type] | None = None,
         index_col: str | int | None = None,
@@ -1831,6 +1832,10 @@ def virtualfile_to_dataset(
             - ``"numpy"`` will return a :class:`numpy.ndarray` object.
             - ``"file"`` means the result was saved to a file and will return ``None``.
             - ``"strings"`` will return the trailing text only as an array of strings.
+        header
+            Row number containing column names for the :class:`pandas.DataFrame` output.
+            ``header=None`` means not to parse the column names from table header.
+            Ignored if the row number is larger than the number of headers in the table.
         column_names
             The column names for the :class:`pandas.DataFrame` output.
         dtype
@@ -1945,7 +1950,7 @@ def virtualfile_to_dataset(
             return result.to_strings()
 
         result = result.to_dataframe(
-            column_names=column_names, dtype=dtype, index_col=index_col
+            header=header, column_names=column_names, dtype=dtype, index_col=index_col
         )
         if output_type == "numpy":  # numpy.ndarray output
            return result.to_numpy()

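The change above is easiest to see from the caller's side. Below is a minimal usage sketch, assuming a plain whitespace-delimited table whose first header line holds the column names; the helper name read_table_with_header is hypothetical and not part of this commit, while the Session calls mirror the dataframe_from_gmt test helper further down.

# Minimal usage sketch (hypothetical helper, not part of this commit).
from pygmt.clib import Session


def read_table_with_header(path):
    """Read a table file and take column names from its first header line."""
    with Session() as lib:
        with lib.virtualfile_out(kind="dataset") as vouttbl:
            # Read the file into a GMT dataset held by a virtual file.
            lib.call_module("read", f"{path} {vouttbl} -Td")
            # header=0 parses column names from the first table header line;
            # header=None (the default) keeps the previous behavior.
            return lib.virtualfile_to_dataset(
                vfname=vouttbl, output_type="pandas", header=0
            )
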
pygmt/datatypes/dataset.py

Lines changed: 19 additions & 5 deletions
@@ -27,6 +27,7 @@ class _GMT_DATASET(ctp.Structure):  # noqa: N801
     >>> with GMTTempFile(suffix=".txt") as tmpfile:
     ...     # Prepare the sample data file
     ...     with Path(tmpfile.name).open(mode="w") as fp:
+    ...         print("# x y z name", file=fp)
     ...         print(">", file=fp)
     ...         print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
     ...         print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
@@ -43,7 +44,8 @@ class _GMT_DATASET(ctp.Structure):  # noqa: N801
     ...         print(ds.min[: ds.n_columns], ds.max[: ds.n_columns])
     ...         # The table
     ...         tbl = ds.table[0].contents
-    ...         print(tbl.n_columns, tbl.n_segments, tbl.n_records)
+    ...         print(tbl.n_columns, tbl.n_segments, tbl.n_records, tbl.n_headers)
+    ...         print(tbl.header[: tbl.n_headers])
     ...         print(tbl.min[: tbl.n_columns], ds.max[: tbl.n_columns])
     ...         for i in range(tbl.n_segments):
     ...             seg = tbl.segment[i].contents
@@ -52,7 +54,8 @@ class _GMT_DATASET(ctp.Structure):  # noqa: N801
     ...             print(seg.text[: seg.n_rows])
     1 3 2
     [1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
-    3 2 4
+    3 2 4 1
+    [b'x y z name']
     [1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
     [1.0, 4.0]
     [2.0, 5.0]
@@ -169,6 +172,7 @@ def to_strings(self) -> np.ndarray[Any, np.dtype[np.str_]]:
 
     def to_dataframe(
         self,
+        header: int | None = None,
         column_names: pd.Index | None = None,
         dtype: type | Mapping[Any, type] | None = None,
         index_col: str | int | None = None,
@@ -187,6 +191,10 @@ def to_dataframe(
         ----------
         column_names
             A list of column names.
+        header
+            Row number containing column names. ``header=None`` means not to parse the
+            column names from table header. Ignored if the row number is larger than the
+            number of headers in the table.
         dtype
             Data type. Can be a single type for all columns or a dictionary mapping
             column names to types.
@@ -207,6 +215,7 @@ def to_dataframe(
         >>> with GMTTempFile(suffix=".txt") as tmpfile:
         ...     # prepare the sample data file
         ...     with Path(tmpfile.name).open(mode="w") as fp:
+        ...         print("# col1 col2 col3 colstr", file=fp)
         ...         print(">", file=fp)
         ...         print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
         ...         print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
@@ -218,12 +227,12 @@ def to_dataframe(
         ...         lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
         ...         ds = lib.read_virtualfile(vouttbl, kind="dataset")
         ...         text = ds.contents.to_strings()
-        ...         df = ds.contents.to_dataframe()
+        ...         df = ds.contents.to_dataframe(header=0)
         >>> text
         array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90',
                'TEXT123 TEXT456789'], dtype='<U18')
         >>> df
-            0    1    2                   3
+         col1  col2  col3              colstr
         0  1.0  2.0  3.0        TEXT1 TEXT23
         1  4.0  5.0  6.0       TEXT4 TEXT567
         2  7.0  8.0  9.0        TEXT8 TEXT90
@@ -248,14 +257,19 @@ def to_dataframe(
         if len(textvector) != 0:
             vectors.append(pd.Series(data=textvector, dtype=pd.StringDtype()))
 
+        if header is not None:
+            tbl = self.table[0].contents  # Use the first table!
+            if header < tbl.n_headers:
+                column_names = tbl.header[header].decode().split()
+
         if len(vectors) == 0:
             # Return an empty DataFrame if no columns are found.
             df = pd.DataFrame(columns=column_names)
         else:
             # Create a DataFrame object by concatenating multiple columns
             df = pd.concat(objs=vectors, axis="columns")
             if column_names is not None:  # Assign column names
-                df.columns = column_names
+                df.columns = column_names[: df.shape[1]]
         if dtype is not None:  # Set dtype for the whole dataset or individual columns
             df = df.astype(dtype)
         if index_col is not None:  # Use a specific column as index

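As the doctest above shows, GMT stores header records without the leading "#", so tbl.header[0] for the sample file is b'x y z name'. The following plain-Python sketch restates the same name-resolution rules without the ctypes structures; the helper names_from_headers and the sample values are illustrative only and not part of the commit.

# Illustrative stand-in for the ctypes-backed logic above: pick a header
# record, split it into names, and never keep more names than columns.
import pandas as pd


def names_from_headers(headers: list[bytes], header: int | None, n_cols: int):
    """Return column names parsed from the chosen header record, or None."""
    if header is None or header >= len(headers):
        # No header requested, or the requested row does not exist:
        # keep pandas' default integer column labels.
        return None
    return headers[header].decode().split()[:n_cols]


names = names_from_headers([b"lon lat z text"], header=0, n_cols=4)
df = pd.DataFrame([[1.0, 2.0, 3.0, "TEXT1 TEXT23"]], columns=names)
print(df.columns.tolist())  # ['lon', 'lat', 'z', 'text']
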
pygmt/tests/test_datatypes_dataset.py

Lines changed: 59 additions & 2 deletions
@@ -40,14 +40,14 @@ def dataframe_from_pandas(filepath_or_buffer, sep=r"\s+", comment="#", header=None):
     return df
 
 
-def dataframe_from_gmt(fname):
+def dataframe_from_gmt(fname, **kwargs):
     """
     Read tabular data as pandas.DataFrame using GMT virtual file.
     """
     with Session() as lib:
         with lib.virtualfile_out(kind="dataset") as vouttbl:
             lib.call_module("read", f"{fname} {vouttbl} -Td")
-            df = lib.virtualfile_to_dataset(vfname=vouttbl)
+            df = lib.virtualfile_to_dataset(vfname=vouttbl, **kwargs)
     return df
 
 
@@ -84,6 +84,63 @@ def test_dataset_empty():
     pd.testing.assert_frame_equal(df, expected_df)
 
 
+def test_dataset_header():
+    """
+    Test parsing column names from dataset header.
+    """
+    with GMTTempFile(suffix=".txt") as tmpfile:
+        with Path(tmpfile.name).open(mode="w") as fp:
+            print("# lon lat z text", file=fp)
+            print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
+            print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
+
+        # Parse column names from the first header line
+        df = dataframe_from_gmt(tmpfile.name, header=0)
+        assert df.columns.tolist() == ["lon", "lat", "z", "text"]
+        # pd.read_csv() can't parse the header line with a leading '#'.
+        # So, we need to skip the header line and manually set the column names.
+        expected_df = dataframe_from_pandas(tmpfile.name, header=None)
+        expected_df.columns = df.columns.tolist()
+        pd.testing.assert_frame_equal(df, expected_df)
+
+
+def test_dataset_header_greater_than_nheaders():
+    """
+    Test passing a header line number that is greater than the number of header lines.
+    """
+    with GMTTempFile(suffix=".txt") as tmpfile:
+        with Path(tmpfile.name).open(mode="w") as fp:
+            print("# lon lat z text", file=fp)
+            print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
+            print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
+
+        # Parse column names from the second header line.
+        df = dataframe_from_gmt(tmpfile.name, header=1)
+        # There is only one header line, so the column names should be default.
+        assert df.columns.tolist() == [0, 1, 2, 3]
+        expected_df = dataframe_from_pandas(tmpfile.name, header=None)
+        pd.testing.assert_frame_equal(df, expected_df)
+
+
+def test_dataset_header_too_many_names():
+    """
+    Test passing a header line with more column names than the number of columns.
+    """
+    with GMTTempFile(suffix=".txt") as tmpfile:
+        with Path(tmpfile.name).open(mode="w") as fp:
+            print("# lon lat z text1 text2", file=fp)
+            print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
+            print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
+
+        df = dataframe_from_gmt(tmpfile.name, header=0)
+        assert df.columns.tolist() == ["lon", "lat", "z", "text1"]
+        # pd.read_csv() can't parse the header line with a leading '#'.
+        # So, we need to skip the header line and manually set the column names.
+        expected_df = dataframe_from_pandas(tmpfile.name, header=None)
+        expected_df.columns = df.columns.tolist()
+        pd.testing.assert_frame_equal(df, expected_df)
+
+
 def test_dataset_to_strings_with_none_values():
     """
     Test that None values in the trailing text doesn't raise an exception.

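For a quick interactive check of the fallback case exercised by test_dataset_header_greater_than_nheaders, the sketch below writes a throwaway one-header file and asks for a header row that does not exist; it is only an illustration of that scenario, not part of the commit.

# Sketch: an out-of-range header row is silently ignored, so the default
# integer column labels are kept (matches the test above).
from pathlib import Path

from pygmt.clib import Session
from pygmt.helpers import GMTTempFile

with GMTTempFile(suffix=".txt") as tmpfile:
    Path(tmpfile.name).write_text("# lon lat z text\n1.0 2.0 3.0 TEXT1\n")
    with Session() as lib:
        with lib.virtualfile_out(kind="dataset") as vouttbl:
            lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
            # Only one header line exists, so header=1 is out of range.
            df = lib.virtualfile_to_dataset(vfname=vouttbl, header=1)
print(df.columns.tolist())  # [0, 1, 2, 3]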