diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index eba799660f3..12f7df5eb45 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1755,9 +1755,10 @@ def virtualfile_from_stringio( @deprecate_parameter( "required_data", "required", "v0.16.0", remove_version="v0.20.0" ) - def virtualfile_in( # noqa: PLR0912 + def virtualfile_in( self, check_kind=None, + kind=None, data=None, x=None, y=None, @@ -1847,7 +1848,9 @@ def virtualfile_in( # noqa: PLR0912 ) mincols = 3 - kind = data_kind(data, required=required) + # Determine the data kind if not given. + if kind is None: + kind = data_kind(data, required=required, check_kind=check_kind) _validate_data_input( data=data, x=x, @@ -1858,16 +1861,6 @@ def virtualfile_in( # noqa: PLR0912 kind=kind, ) - if check_kind: - valid_kinds = ("file", "arg") if required is False else ("file",) - if check_kind == "raster": - valid_kinds += ("grid", "image") - elif check_kind == "vector": - valid_kinds += ("empty", "matrix", "vectors", "geojson") - if kind not in valid_kinds: - msg = f"Unrecognized data type for {check_kind}: {type(data)}." - raise GMTInvalidInput(msg) - # Decide which virtualfile_from_ function to use _virtualfile_from = { "arg": contextlib.nullcontext, diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 387d61e03b3..b0209f5b814 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -41,6 +41,11 @@ "ISO-8859-16", ] +# Type hints for the list of data kinds. +Kind = Literal[ + "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" +] + def _validate_data_input( # noqa: PLR0912 data=None, x=None, y=None, z=None, required=True, mincols=2, kind=None @@ -272,11 +277,11 @@ def _check_encoding(argstr: str) -> Encoding: return "ISOLatin1+" -def data_kind( - data: Any, required: bool = True -) -> Literal[ - "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" -]: +def data_kind( # noqa: PLR0912 + data: Any, + required: bool = True, + check_kind: Kind | Sequence[Kind] | Literal["raster", "vector"] | None = None, +) -> Kind: r""" Check the kind of data that is provided to a module. @@ -307,6 +312,14 @@ def data_kind( required Whether 'data' is required. Set to ``False`` when dealing with optional virtual files. + check_kind + Used to validate the type of data that can be passed in. Valid values are: + + - Any recognized data kind + - A list/tuple of recognized data kinds + - ``"raster"``: shorthand for a sequence of raster-like data kinds + - ``"vector"``: shorthand for a sequence of vector-like data kinds + - ``None``: means no validatation. Returns ------- @@ -414,6 +427,24 @@ def data_kind( kind = "matrix" case _: # Fall back to "vectors" if data is None and required=True. kind = "vectors" + + # Now start to check if the data kind is valid. + if check_kind is not None: + valid_kinds = ("file", "arg") if required is False else ("file",) + match check_kind: + case "raster": + valid_kinds += ("grid", "image") + case "vector": + valid_kinds += ("empty", "matrix", "vectors", "geojson") + case str(): + valid_kinds = (check_kind,) + case list() | tuple(): + valid_kinds = check_kind + + if kind not in valid_kinds: + msg = f"Unrecognized data type: {type(data)}." + raise GMTInvalidInput(msg) + return kind # type: ignore[return-value] diff --git a/pygmt/src/grdcut.py b/pygmt/src/grdcut.py index 2d5b1f0e5c9..b200a5c118b 100644 --- a/pygmt/src/grdcut.py +++ b/pygmt/src/grdcut.py @@ -117,7 +117,7 @@ def grdcut( raise GMTInvalidInput(msg) # Determine the output data kind based on the input data kind. - match inkind := data_kind(grid): + match inkind := data_kind(grid, check_kind="raster"): case "grid" | "image": outkind = inkind case "file": @@ -128,7 +128,7 @@ def grdcut( with Session() as lib: with ( - lib.virtualfile_in(check_kind="raster", data=grid) as vingrd, + lib.virtualfile_in(data=grid, kind=inkind) as vingrd, lib.virtualfile_out(kind=outkind, fname=outgrid) as voutgrd, ): kwargs["G"] = voutgrd diff --git a/pygmt/src/legend.py b/pygmt/src/legend.py index 2cb2eddcf95..df2636eb108 100644 --- a/pygmt/src/legend.py +++ b/pygmt/src/legend.py @@ -89,14 +89,11 @@ def legend( if kwargs.get("F") is None: kwargs["F"] = box - kind = data_kind(spec) - if kind not in {"empty", "file", "stringio"}: - msg = f"Unrecognized data type: {type(spec)}" - raise GMTInvalidInput(msg) + kind = data_kind(spec, check_kind=("empty", "file", "stringio")) if kind == "file" and is_nonstr_iter(spec): msg = "Only one legend specification file is allowed." raise GMTInvalidInput(msg) with Session() as lib: - with lib.virtualfile_in(data=spec, required=False) as vintbl: + with lib.virtualfile_in(data=spec, required=False, kind=kind) as vintbl: lib.call_module(module="legend", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/meca.py b/pygmt/src/meca.py index 0b4cbc3f23a..4111be1e162 100644 --- a/pygmt/src/meca.py +++ b/pygmt/src/meca.py @@ -49,7 +49,7 @@ def _preprocess_spec(spec, colnames, override_cols): Dictionary of column names and values to override in the input data. Only makes sense if ``spec`` is a dict or :class:`pandas.DataFrame`. """ - kind = data_kind(spec) # Determine the kind of the input data. + kind = data_kind(spec, check_kind="vector") # Determine the kind of the input data. # Convert pandas.DataFrame and numpy.ndarray to dict. if isinstance(spec, pd.DataFrame): @@ -360,5 +360,5 @@ def meca( # noqa: PLR0913 kwargs["A"] = _auto_offset(spec) kwargs["S"] = f"{_convention.code}{scale}" with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=spec) as vintbl: + with lib.virtualfile_in(data=spec) as vintbl: lib.call_module(module="meca", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index 35dab4aa009..5fcaffa6aea 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -232,8 +232,9 @@ def plot( # noqa: PLR0912 # parameter. self._activate_figure() - kind = data_kind(data) + kind = data_kind(data, check_kind="vector") if kind == "empty": # Data is given via a series of vectors. + kind = "vectors" data = {"x": x, "y": y} # Parameters for vector styles if ( @@ -280,5 +281,5 @@ def plot( # noqa: PLR0912 kwargs["S"] = "s0.2c" with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=data) as vintbl: + with lib.virtualfile_in(data=data, kind=kind) as vintbl: lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index 491ebbcf9df..17028cfba3c 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -210,8 +210,9 @@ def plot3d( # noqa: PLR0912 # parameter. self._activate_figure() - kind = data_kind(data) + kind = data_kind(data, check_kind="vector") if kind == "empty": # Data is given via a series of vectors. + kind = "vectors" data = {"x": x, "y": y, "z": z} # Parameters for vector styles if ( @@ -259,5 +260,5 @@ def plot3d( # noqa: PLR0912 kwargs["S"] = "u0.2c" with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=data, mincols=3) as vintbl: + with lib.virtualfile_in(data=data, mincols=3, kind=kind) as vintbl: lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/text.py b/pygmt/src/text.py index 99a3180415a..7a0c490ce75 100644 --- a/pygmt/src/text.py +++ b/pygmt/src/text.py @@ -43,7 +43,7 @@ w="wrap", ) @kwargs_to_strings(R="sequence", c="sequence_comma", p="sequence") -def text_( # noqa: PLR0912 +def text_( # noqa: PLR0912, PLR0915 self, textfiles: PathLike | TableLike | None = None, x=None, @@ -192,7 +192,7 @@ def text_( # noqa: PLR0912 raise GMTInvalidInput(msg) data_is_required = position is None - kind = data_kind(textfiles, required=data_is_required) + kind = data_kind(textfiles, required=data_is_required, check_kind="vector") if position is not None and (text is None or is_nonstr_iter(text)): msg = "'text' can't be None or array when 'position' is given." @@ -226,6 +226,7 @@ def text_( # noqa: PLR0912 confdict = {} data = None if kind == "empty": + kind = "vectors" data = {"x": x, "y": y} for arg, flag, name in array_args: @@ -262,7 +263,9 @@ def text_( # noqa: PLR0912 with Session() as lib: with lib.virtualfile_in( - check_kind="vector", data=textfiles or data, required=data_is_required + data=textfiles or data, + required=data_is_required, + kind=kind, ) as vintbl: lib.call_module( module="text", diff --git a/pygmt/src/x2sys_cross.py b/pygmt/src/x2sys_cross.py index d502d72c190..ff21afd0ac0 100644 --- a/pygmt/src/x2sys_cross.py +++ b/pygmt/src/x2sys_cross.py @@ -195,7 +195,7 @@ def x2sys_cross( file_contexts: list[contextlib.AbstractContextManager[Any]] = [] for track in tracks: - match data_kind(track): + match data_kind(track, check_kind="vector"): case "file": file_contexts.append(contextlib.nullcontext(track)) case "vectors":