Skip to content

Commit 8220180

Browse files
authored
pygmt.grdtrack: Add 'output_type' parameter for output in pandas/numpy/file formats (#3106)
1 parent 83b1a12 commit 8220180

File tree

1 file changed

+42
-38
lines changed

1 file changed

+42
-38
lines changed

pygmt/src/grdtrack.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
grdtrack - Sample grids at specified (x,y) locations.
33
"""
44

5+
from typing import Literal
6+
7+
import numpy as np
58
import pandas as pd
69
from pygmt.clib import Session
710
from pygmt.exceptions import GMTInvalidInput
811
from pygmt.helpers import (
9-
GMTTempFile,
1012
build_arg_string,
1113
fmt_docstring,
1214
kwargs_to_strings,
1315
use_alias,
16+
validate_output_table_type,
1417
)
1518

1619
__doctest_skip__ = ["grdtrack"]
@@ -44,7 +47,14 @@
4447
w="wrap",
4548
)
4649
@kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
47-
def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
50+
def grdtrack(
51+
grid,
52+
points=None,
53+
output_type: Literal["pandas", "numpy", "file"] = "pandas",
54+
outfile: str | None = None,
55+
newcolname=None,
56+
**kwargs,
57+
) -> pd.DataFrame | np.ndarray | None:
4858
r"""
4959
Sample grids at specified (x,y) locations.
5060
@@ -73,15 +83,12 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
7383
points : str, {table-like}
7484
Pass in either a file name to an ASCII data table, a 2-D
7585
{table-classes}.
76-
86+
{output_type}
87+
{outfile}
7788
newcolname : str
7889
Required if ``points`` is a :class:`pandas.DataFrame`. The name for the
7990
new column in the track :class:`pandas.DataFrame` table where the
8091
sampled values will be placed.
81-
82-
outfile : str
83-
The file name for the output ASCII file.
84-
8592
resample : str
8693
**f**\|\ **p**\|\ **m**\|\ **r**\|\ **R**\ [**+l**]
8794
For track resampling (if ``crossprofile`` or ``profile`` are set) we
@@ -258,13 +265,13 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
258265
259266
Returns
260267
-------
261-
track: pandas.DataFrame or None
262-
Return type depends on whether the ``outfile`` parameter is set:
268+
ret
269+
Return type depends on ``outfile`` and ``output_type``:
263270
264-
- :class:`pandas.DataFrame` table with (x, y, ..., newcolname) if
265-
``outfile`` is not set
266-
- None if ``outfile`` is set (track output will be stored in file set
267-
by ``outfile``)
271+
- ``None`` if ``outfile`` is set (output will be stored in file set by
272+
``outfile``)
273+
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
274+
(depends on ``output_type``)
268275
269276
Example
270277
-------
@@ -291,30 +298,27 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
291298
if hasattr(points, "columns") and newcolname is None:
292299
raise GMTInvalidInput("Please pass in a str to 'newcolname'")
293300

294-
with GMTTempFile(suffix=".csv") as tmpfile:
295-
with Session() as lib:
296-
with (
297-
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
298-
lib.virtualfile_in(
299-
check_kind="vector", data=points, required_data=False
300-
) as vintbl,
301-
):
302-
kwargs["G"] = vingrd
303-
if outfile is None: # Output to tmpfile if outfile is not set
304-
outfile = tmpfile.name
305-
lib.call_module(
306-
module="grdtrack",
307-
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
308-
)
301+
output_type = validate_output_table_type(output_type, outfile=outfile)
309302

310-
# Read temporary csv output to a pandas table
311-
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
312-
try:
313-
column_names = [*points.columns.to_list(), newcolname]
314-
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
315-
except AttributeError: # 'str' object has no attribute 'columns'
316-
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
317-
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
318-
result = None
303+
column_names = None
304+
if output_type == "pandas" and isinstance(points, pd.DataFrame):
305+
column_names = [*points.columns.to_list(), newcolname]
319306

320-
return result
307+
with Session() as lib:
308+
with (
309+
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
310+
lib.virtualfile_in(
311+
check_kind="vector", data=points, required_data=False
312+
) as vintbl,
313+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
314+
):
315+
kwargs["G"] = vingrd
316+
lib.call_module(
317+
module="grdtrack",
318+
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
319+
)
320+
return lib.virtualfile_to_dataset(
321+
output_type=output_type,
322+
vfname=vouttbl,
323+
column_names=column_names,
324+
)

0 commit comments

Comments
 (0)