Skip to content

Commit cf1d6df

Browse files
committed
Extend the data_kind function to validate the kinds
1 parent aa7c658 commit cf1d6df

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

pygmt/helpers/utils.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@
4141
"ISO-8859-16",
4242
]
4343

44+
# Type hints for the list of data kinds.
45+
Kind = Literal[
46+
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
47+
]
48+
4449

4550
def _validate_data_input( # noqa: PLR0912
4651
data=None, x=None, y=None, z=None, required=True, mincols=2, kind=None
@@ -272,11 +277,11 @@ def _check_encoding(argstr: str) -> Encoding:
272277
return "ISOLatin1+"
273278

274279

275-
def data_kind(
276-
data: Any, required: bool = True
277-
) -> Literal[
278-
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
279-
]:
280+
def data_kind( # noqa: PLR0912
281+
data: Any,
282+
required: bool = True,
283+
check_kind: Kind | Sequence[Kind] | Literal["raster", "vector"] | None = None,
284+
) -> Kind:
280285
r"""
281286
Check the kind of data that is provided to a module.
282287
@@ -307,6 +312,14 @@ def data_kind(
307312
required
308313
Whether 'data' is required. Set to ``False`` when dealing with optional virtual
309314
files.
315+
check_kind
316+
Used to validate the type of data that can be passed in. Valid values are:
317+
318+
- Any recognized data kind
319+
- A list/tuple of recognized data kinds
320+
- ``"raster"``: shorthand for a sequence of raster-like data kinds
321+
- ``"vector"``: shorthand for a sequence of vector-like data kinds
322+
- ``None``: means no validatation.
310323
311324
Returns
312325
-------
@@ -414,6 +427,24 @@ def data_kind(
414427
kind = "matrix"
415428
case _: # Fall back to "vectors" if data is None and required=True.
416429
kind = "vectors"
430+
431+
# Now start to check if the data kind is valid.
432+
if check_kind is not None:
433+
valid_kinds = ("file", "arg") if required is False else ("file",)
434+
match check_kind:
435+
case "raster":
436+
valid_kinds += ("grid", "image")
437+
case "vector":
438+
valid_kinds += ("empty", "matrix", "vectors", "geojson")
439+
case str():
440+
valid_kinds = (check_kind,)
441+
case list() | tuple():
442+
valid_kinds = check_kind
443+
444+
if kind not in valid_kinds:
445+
msg = f"Unrecognized data type: {type(data)}."
446+
raise GMTInvalidInput(msg)
447+
417448
return kind # type: ignore[return-value]
418449

419450

0 commit comments

Comments
 (0)