Skip to content

Commit 6deb388

Browse files
authored
Allow passing in pandas dataframes to x2sys_cross (#591)
Run crossover analysis directly on pandas.DataFrame inputs instead of having to write to tab-separated value (TSV) files first! Implemented by storing pandas.DataFrame data in a temporary file and passing this intermediate file to x2sys_cross. Need to do some file parsing to get the right file extension (suffix) for this to work. * Use tempfile_from_dftrack instead of tempfile_from_buffer * Don't use GMTTempFile, just generate random filename and write to it * Reduce git diff and make Windows tests pass by ignoring permission error * Test input two pandas dataframes to x2sys_cross with time column Renamed 'result' to 'table' to prevent pylint complaining about R0914: Too many local variables (16/15) (too-many-locals) * Improve docstring of x2sys_cross and tempfile_from_dftrack
1 parent c06fa44 commit 6deb388

File tree

2 files changed

+104
-21
lines changed

2 files changed

+104
-21
lines changed

pygmt/tests/test_x2sys_cross.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def fixture_tracks():
3131
Load track data from the sample bathymetry file
3232
"""
3333
dataframe = load_sample_bathymetry()
34-
return [dataframe.query(expr="bathymetry > -20")] # reduce size of dataset
34+
dataframe.columns = ["x", "y", "z"] # longitude, latitude, bathymetry
35+
return [dataframe.query(expr="z > -20")] # reduce size of dataset
3536

3637

3738
def test_x2sys_cross_input_file_output_file(mock_x2sys_home):
@@ -76,25 +77,57 @@ def test_x2sys_cross_input_file_output_dataframe(mock_x2sys_home):
7677
def test_x2sys_cross_input_dataframe_output_dataframe(mock_x2sys_home, tracks):
7778
"""
7879
Run x2sys_cross by passing in one dataframe, and output external crossovers
79-
to a pandas.DataFrame. Not actually implemented yet, wait for
80-
https://github.com/GenericMappingTools/gmt/issues/3717
80+
to a pandas.DataFrame.
8181
"""
8282
with TemporaryDirectory(prefix="X2SYS", dir=os.getcwd()) as tmpdir:
8383
tag = os.path.basename(tmpdir)
8484
x2sys_init(tag=tag, fmtfile="xyz", force=True)
8585

86-
with pytest.raises(NotImplementedError):
87-
_ = x2sys_cross(tracks=tracks, tag=tag, coe="i", verbose="i")
86+
output = x2sys_cross(tracks=tracks, tag=tag, coe="i", verbose="i")
8887

89-
# assert isinstance(output, pd.DataFrame)
90-
# assert output.shape == (4, 12)
91-
# columns = list(output.columns)
92-
# assert columns[:6] == ["x", "y", "t_1", "t_2", "dist_1", "dist_2"]
93-
# assert columns[6:] == ["head_1","head_2","vel_1","vel_2","z_X","z_M"]
94-
# assert output.dtypes["t_1"].type == np.datetime64
95-
# assert output.dtypes["t_2"].type == np.datetime64
88+
assert isinstance(output, pd.DataFrame)
89+
assert output.shape == (14, 12)
90+
columns = list(output.columns)
91+
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
92+
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
93+
assert output.dtypes["i_1"].type == np.object_
94+
assert output.dtypes["i_2"].type == np.object_
95+
96+
return output
9697

97-
# return output
98+
99+
def test_x2sys_cross_input_two_dataframes(mock_x2sys_home):
100+
"""
101+
Run x2sys_cross by passing in two pandas.DataFrame tables with a time
102+
column, and output external crossovers to a pandas.DataFrame
103+
"""
104+
with TemporaryDirectory(prefix="X2SYS", dir=os.getcwd()) as tmpdir:
105+
tag = os.path.basename(tmpdir)
106+
x2sys_init(
107+
tag=tag, fmtfile="xyz", suffix="xyzt", units=["de", "se"], force=True
108+
)
109+
110+
# Add a time row to the x2sys fmtfile
111+
with open(file=os.path.join(tmpdir, "xyz.fmt"), mode="a") as fmtfile:
112+
fmtfile.write("time\ta\tN\t0\t1\t0\t%g\n")
113+
114+
# Create pandas.DataFrame track tables
115+
tracks = []
116+
for i in range(2):
117+
np.random.seed(seed=i)
118+
track = pd.DataFrame(data=np.random.rand(10, 3), columns=("x", "y", "z"))
119+
track["time"] = pd.date_range(start=f"2020-{i}1-01", periods=10, freq="ms")
120+
tracks.append(track)
121+
122+
output = x2sys_cross(tracks=tracks, tag=tag, coe="e", verbose="i")
123+
124+
assert isinstance(output, pd.DataFrame)
125+
assert output.shape == (30, 12)
126+
columns = list(output.columns)
127+
assert columns[:6] == ["x", "y", "t_1", "t_2", "dist_1", "dist_2"]
128+
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
129+
assert output.dtypes["t_1"].type == np.datetime64
130+
assert output.dtypes["t_2"].type == np.datetime64
98131

99132

100133
def test_x2sys_cross_input_two_filenames(mock_x2sys_home):
@@ -131,7 +164,7 @@ def test_x2sys_cross_invalid_tracks_input_type(tracks):
131164
Run x2sys_cross using tracks input that is not a pandas.DataFrame (matrix)
132165
or str (file) type, which would raise a GMTInvalidInput error.
133166
"""
134-
invalid_tracks = tracks[0].to_xarray().bathymetry
167+
invalid_tracks = tracks[0].to_xarray().z
135168
assert data_kind(invalid_tracks) == "grid"
136169
with pytest.raises(GMTInvalidInput):
137170
x2sys_cross(tracks=[invalid_tracks])

pygmt/x2sys.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
GMT supplementary X2SYS module for crossover analysis.
33
"""
44
import contextlib
5+
import os
6+
from pathlib import Path
57

68
import pandas as pd
79

@@ -14,10 +16,45 @@
1416
dummy_context,
1517
fmt_docstring,
1618
kwargs_to_strings,
19+
unique_name,
1720
use_alias,
1821
)
1922

2023

24+
@contextlib.contextmanager
25+
def tempfile_from_dftrack(track, suffix):
26+
"""
27+
Saves pandas.DataFrame track table to a temporary tab-separated ASCII text
28+
file with a unique name (to prevent clashes when running x2sys_cross),
29+
adding a suffix extension to the end.
30+
31+
Parameters
32+
----------
33+
track : pandas.DataFrame
34+
A table holding track data with coordinate (x, y) or (lon, lat) values,
35+
and (optionally) time (t).
36+
suffix : str
37+
File extension, e.g. xyz, tsv, etc.
38+
39+
Yields
40+
------
41+
tmpfilename : str
42+
A temporary tab-separated value file with a unique name holding the
43+
track data. E.g. 'track-1a2b3c4.tsv'.
44+
"""
45+
try:
46+
tmpfilename = f"track-{unique_name()[:7]}.{suffix}"
47+
track.to_csv(
48+
path_or_buf=tmpfilename,
49+
sep="\t",
50+
index=False,
51+
date_format="%Y-%m-%dT%H:%M:%S.%fZ",
52+
)
53+
yield tmpfilename
54+
finally:
55+
os.remove(tmpfilename)
56+
57+
2158
@fmt_docstring
2259
@use_alias(
2360
D="fmtfile",
@@ -158,9 +195,10 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
158195
159196
Parameters
160197
----------
161-
tracks : str or list
198+
tracks : pandas.DataFrame or str or list
162199
A table or a list of tables with (x, y) or (lon, lat) values in the
163-
first two columns. Supported formats are ASCII, native binary, or
200+
first two columns. Track(s) can be provided as pandas DataFrame tables
201+
or file names. Supported file formats are ASCII, native binary, or
164202
COARDS netCDF 1-D data. More columns may also be present.
165203
166204
If the filenames are missing their file extension, we will append the
@@ -263,8 +301,20 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
263301
if kind == "file":
264302
file_contexts.append(dummy_context(track))
265303
elif kind == "matrix":
266-
raise NotImplementedError(f"{type(track)} inputs are not supported yet")
267-
# file_contexts.append(lib.virtualfile_from_matrix(track.values))
304+
# find suffix (-E) of trackfiles used (e.g. xyz, csv, etc) from
305+
# $X2SYS_HOME/TAGNAME/TAGNAME.tag file
306+
lastline = (
307+
Path(os.environ["X2SYS_HOME"], kwargs["T"], f"{kwargs['T']}.tag")
308+
.read_text()
309+
.strip()
310+
.split("\n")[-1]
311+
) # e.g. "-Dxyz -Etsv -I1/1"
312+
for item in sorted(lastline.split()): # sort list alphabetically
313+
if item.startswith(("-E", "-D")): # prefer -Etsv over -Dxyz
314+
suffix = item[2:] # e.g. tsv (1st choice) or xyz (2nd choice)
315+
316+
# Save pandas.DataFrame track data to temporary file
317+
file_contexts.append(tempfile_from_dftrack(track=track, suffix=suffix))
268318
else:
269319
raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")
270320

@@ -287,8 +337,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
287337
parse_dates=[2, 3], # Datetimes on 3rd and 4th column
288338
)
289339
# Remove the "# " from "# x" in the first column
290-
result = table.rename(columns={table.columns[0]: table.columns[0][2:]})
340+
table = table.rename(columns={table.columns[0]: table.columns[0][2:]})
291341
elif outfile != tmpfile.name: # if outfile is set, output in outfile only
292-
result = None
342+
table = None
293343

294-
return result
344+
return table

0 commit comments

Comments
 (0)