Skip to content

Commit a5d8b14

Browse files
seisman and weiji14 authored
Wrap GMT's standard data type GMT_DATASET for table input/output (#2729)
Co-authored-by: Wei Ji <[email protected]>
1 parent f08cb94 commit a5d8b14

File tree

1 file changed

+205
-1
lines changed

1 file changed

+205
-1
lines changed

pygmt/datatypes/dataset.py

Lines changed: 205 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,211 @@
33
"""
44

55
import ctypes as ctp
6+
from typing import ClassVar
7+
8+
import numpy as np
9+
import pandas as pd
610

711

812
class _GMT_DATASET(ctp.Structure):  # noqa: N801
    """
    GMT dataset structure for holding multiple tables (files).

    This class is only meant for internal use by PyGMT and is not exposed to users.
    See the GMT source code gmt_resources.h for the original C struct definitions.

    The structure is hierarchical: a dataset holds tables, a table holds segments,
    and a segment holds the numeric columns plus an optional trailing text per
    row. The nested classes are defined innermost-first because each outer
    ``_fields_`` list refers to the class nested inside it.

    Examples
    --------
    >>> from pygmt.helpers import GMTTempFile
    >>> from pygmt.clib import Session
    >>>
    >>> with GMTTempFile(suffix=".txt") as tmpfile:
    ...     # Prepare the sample data file
    ...     with open(tmpfile.name, mode="w") as fp:
    ...         print(">", file=fp)
    ...         print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
    ...         print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
    ...         print(">", file=fp)
    ...         print("7.0 8.0 9.0 TEXT8 TEXT90", file=fp)
    ...         print("10.0 11.0 12.0 TEXT123 TEXT456789", file=fp)
    ...     # Read the data file
    ...     with Session() as lib:
    ...         with lib.virtualfile_out(kind="dataset") as vouttbl:
    ...             lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
    ...             # The dataset
    ...             ds = lib.read_virtualfile(vouttbl, kind="dataset").contents
    ...             print(ds.n_tables, ds.n_columns, ds.n_segments)
    ...             print(ds.min[: ds.n_columns], ds.max[: ds.n_columns])
    ...             # The table
    ...             tbl = ds.table[0].contents
    ...             print(tbl.n_columns, tbl.n_segments, tbl.n_records)
    ...             print(tbl.min[: tbl.n_columns], tbl.max[: tbl.n_columns])
    ...             for i in range(tbl.n_segments):
    ...                 seg = tbl.segment[i].contents
    ...                 for j in range(seg.n_columns):
    ...                     print(seg.data[j][: seg.n_rows])
    ...                 print(seg.text[: seg.n_rows])
    1 3 2
    [1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
    3 2 4
    [1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
    [1.0, 4.0]
    [2.0, 5.0]
    [3.0, 6.0]
    [b'TEXT1 TEXT23', b'TEXT4 TEXT567']
    [7.0, 10.0]
    [8.0, 11.0]
    [9.0, 12.0]
    [b'TEXT8 TEXT90', b'TEXT123 TEXT456789']
    """

    class _GMT_DATATABLE(ctp.Structure):  # noqa: N801
        """
        GMT datatable structure for holding a table with multiple segments.
        """

        class _GMT_DATASEGMENT(ctp.Structure):  # noqa: N801
            """
            GMT datasegment structure for holding a segment with multiple columns.
            """

            # NOTE: ctypes derives the memory layout from the order and types of
            # these entries, so they must not be reordered or retyped.
            _fields_: ClassVar = [
                # Number of rows/records in this segment
                ("n_rows", ctp.c_uint64),
                # Number of fields in each record
                ("n_columns", ctp.c_uint64),
                # Minimum coordinate for each column
                ("min", ctp.POINTER(ctp.c_double)),
                # Maximum coordinate for each column
                ("max", ctp.POINTER(ctp.c_double)),
                # Data x, y, and possibly other columns
                ("data", ctp.POINTER(ctp.POINTER(ctp.c_double))),
                # Label string (if applicable)
                ("label", ctp.c_char_p),
                # Segment header (if applicable)
                ("header", ctp.c_char_p),
                # text beyond the data
                ("text", ctp.POINTER(ctp.c_char_p)),
                # Book-keeping variables "hidden" from the API
                ("hidden", ctp.c_void_p),
            ]

        # NOTE: ctypes derives the memory layout from the order and types of
        # these entries, so they must not be reordered or retyped.
        _fields_: ClassVar = [
            # Number of file header records (0 if no header)
            ("n_headers", ctp.c_uint),
            # Number of columns (fields) in each record
            ("n_columns", ctp.c_uint64),
            # Number of segments in the array
            ("n_segments", ctp.c_uint64),
            # Total number of data records across all segments
            ("n_records", ctp.c_uint64),
            # Minimum coordinate for each column
            ("min", ctp.POINTER(ctp.c_double)),
            # Maximum coordinate for each column
            ("max", ctp.POINTER(ctp.c_double)),
            # Array with all file header records, if any
            ("header", ctp.POINTER(ctp.c_char_p)),
            # Pointer to array of segments
            ("segment", ctp.POINTER(ctp.POINTER(_GMT_DATASEGMENT))),
            # Book-keeping variables "hidden" from the API
            ("hidden", ctp.c_void_p),
        ]

    # NOTE: ctypes derives the memory layout from the order and types of these
    # entries, so they must not be reordered or retyped.
    _fields_: ClassVar = [
        # The total number of tables (files) contained
        ("n_tables", ctp.c_uint64),
        # The number of data columns
        ("n_columns", ctp.c_uint64),
        # The total number of segments across all tables
        ("n_segments", ctp.c_uint64),
        # The total number of data records across all tables
        ("n_records", ctp.c_uint64),
        # Minimum coordinate for each column
        ("min", ctp.POINTER(ctp.c_double)),
        # Maximum coordinate for each column
        ("max", ctp.POINTER(ctp.c_double)),
        # Pointer to array of tables
        ("table", ctp.POINTER(ctp.POINTER(_GMT_DATATABLE))),
        # The datatype (numerical, text, or mixed) of this dataset
        ("type", ctp.c_int32),
        # The geometry of this dataset
        ("geometry", ctp.c_int32),
        # To store a referencing system string in PROJ.4 format
        ("ProjRefPROJ4", ctp.c_char_p),
        # To store a referencing system string in WKT format
        ("ProjRefWKT", ctp.c_char_p),
        # To store a referencing system EPSG code
        ("ProjRefEPSG", ctp.c_int),
        # Book-keeping variables "hidden" from the API
        ("hidden", ctp.c_void_p),
    ]

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert a _GMT_DATASET object to a :class:`pandas.DataFrame` object.

        Currently, the number of columns in all segments of all tables are assumed to be
        the same. The same column in all segments of all tables are concatenated. The
        trailing text column is also concatenated as a single string column.

        Segments whose ``text`` pointer is NULL contribute nothing to the text
        column. If the dataset has neither numeric columns nor trailing text,
        an empty :class:`pandas.DataFrame` is returned.

        Returns
        -------
        df
            A :class:`pandas.DataFrame` object. Columns are labelled with
            consecutive integers starting at 0; the trailing text column, if
            present, comes last.

        Examples
        --------
        >>> from pygmt.helpers import GMTTempFile
        >>> from pygmt.clib import Session
        >>>
        >>> with GMTTempFile(suffix=".txt") as tmpfile:
        ...     # prepare the sample data file
        ...     with open(tmpfile.name, mode="w") as fp:
        ...         print(">", file=fp)
        ...         print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
        ...         print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
        ...         print(">", file=fp)
        ...         print("7.0 8.0 9.0 TEXT8 TEXT90", file=fp)
        ...         print("10.0 11.0 12.0 TEXT123 TEXT456789", file=fp)
        ...     with Session() as lib:
        ...         with lib.virtualfile_out(kind="dataset") as vouttbl:
        ...             lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
        ...             ds = lib.read_virtualfile(vouttbl, kind="dataset")
        ...             df = ds.contents.to_dataframe()
        >>> df
              0     1     2                  3
        0   1.0   2.0   3.0       TEXT1 TEXT23
        1   4.0   5.0   6.0      TEXT4 TEXT567
        2   7.0   8.0   9.0       TEXT8 TEXT90
        3  10.0  11.0  12.0  TEXT123 TEXT456789
        >>> df.dtypes.to_list()
        [dtype('float64'), dtype('float64'), dtype('float64'), string[python]]
        """
        # Walk the table/segment hierarchy once, keeping the segments in
        # table-then-segment order so every column is stitched together in the
        # same row order.
        segments = []
        for itbl in range(self.n_tables):
            dtbl = self.table[itbl].contents
            segments.extend(
                dtbl.segment[iseg].contents for iseg in range(dtbl.n_segments)
            )

        # Deal with numeric columns
        vectors = []
        for icol in range(self.n_columns):
            chunks = [
                np.ctypeslib.as_array(dseg.data[icol], shape=(dseg.n_rows,))
                for dseg in segments
            ]
            # np.concatenate rejects an empty list, so a dataset that declares
            # columns but holds no segments yields empty float64 columns.
            if chunks:
                values = np.concatenate(chunks)
            else:
                values = np.empty(0, dtype=np.float64)
            vectors.append(pd.Series(data=values))

        # Deal with trailing text column
        textvector = []
        for dseg in segments:
            if dseg.text:  # NULL pointer means this segment has no trailing text
                textvector.extend(dseg.text[: dseg.n_rows])
        if textvector:
            vectors.append(
                pd.Series(data=np.char.decode(textvector), dtype=pd.StringDtype())
            )

        # pd.concat raises ValueError when given nothing to concatenate.
        if not vectors:
            return pd.DataFrame()
        return pd.concat(objs=vectors, axis=1)

0 commit comments

Comments
 (0)