Skip to content

Commit 722dae9

Browse files
weiji14Meghan Jonesseisman
authored
Allow passing None explicitly to pygmt functions Part 1 (#1857)
Implements a more robust check for None values in pygmt functions. * Let grdimage shading=None or False work Refactor grdimage to check `if "I" in kwargs` to using `if kwargs.get("I") is not None`. * Let grd2cpt's categorical, cyclic and output work with None input * Let grd2xyz's outcols work with None input Specifically when output_type="pandas" too. * Let grdgradient's tiles, normalize and outgrid work with None input * Let grdview's drapegrid work with None inputs * Let makecpt's categorical, cyclic and output work with None inputs * Let plot's style, color, intensity and transparency work with None input * Let plot3d's style, color, intensity & transparency work with None input * Let solar's T work with None input * Let transparency work with 0, None and False input * Let project's center, convention and generate work with None inputs * Let velo's spec work with None inputs Or rather, catch it properly if someone uses spec=None. * Update pygmt/src/grdgradient.py using walrus operator Co-authored-by: Meghan Jones <[email protected]> Co-authored-by: Dongdong Tian <[email protected]>
1 parent 61781e4 commit 722dae9

16 files changed

+72
-34
lines changed

pygmt/src/grd2cpt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,14 @@ def grd2cpt(grid, **kwargs):
160160
``categorical=True``.
161161
{V}
162162
"""
163-
if "W" in kwargs and "Ww" in kwargs:
163+
if kwargs.get("W") is not None and kwargs.get("Ww") is not None:
164164
raise GMTInvalidInput("Set only categorical or cyclic to True, not both.")
165165
with Session() as lib:
166166
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
167167
with file_context as infile:
168-
if "H" not in kwargs: # if no output is set
168+
if kwargs.get("H") is None: # if no output is set
169169
arg_str = build_arg_string(kwargs, infile=infile)
170-
if "H" in kwargs: # if output is set
170+
else: # if output is set
171171
outfile, kwargs["H"] = kwargs["H"], True
172172
if not outfile or not isinstance(outfile, str):
173173
raise GMTInvalidInput("'output' should be a proper file name.")

pygmt/src/grd2xyz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
159159
elif outfile is None and output_type == "file":
160160
raise GMTInvalidInput("Must specify 'outfile' for ASCII output.")
161161

162-
if "o" in kwargs and output_type == "pandas":
162+
if kwargs.get("o") is not None and output_type == "pandas":
163163
raise GMTInvalidInput(
164164
"If 'outcols' is specified, 'output_type' must be either 'numpy'"
165165
"or 'file'."

pygmt/src/grdgradient.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def grdgradient(grid, **kwargs):
164164
>>> new_grid = pygmt.grdgradient(grid=grid, azimuth=10)
165165
"""
166166
with GMTTempFile(suffix=".nc") as tmpfile:
167-
if "Q" in kwargs and "N" not in kwargs:
167+
if kwargs.get("Q") is not None and kwargs.get("N") is None:
168168
raise GMTInvalidInput("""Must specify normalize if tiles is specified.""")
169169
if not args_in_kwargs(args=["A", "D", "E"], kwargs=kwargs):
170170
raise GMTInvalidInput(
@@ -174,9 +174,8 @@ def grdgradient(grid, **kwargs):
174174
with Session() as lib:
175175
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
176176
with file_context as infile:
177-
if "G" not in kwargs: # if outgrid is unset, output to tempfile
178-
kwargs.update({"G": tmpfile.name})
179-
outgrid = kwargs["G"]
177+
if (outgrid := kwargs.get("G")) is None:
178+
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
180179
lib.call_module("grdgradient", build_arg_string(kwargs, infile=infile))
181180

182181
return load_dataarray(outgrid) if outgrid == tmpfile.name else None

pygmt/src/grdimage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def grdimage(self, grid, **kwargs):
166166
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
167167
with contextlib.ExitStack() as stack:
168168
# shading using an xr.DataArray
169-
if "I" in kwargs and data_kind(kwargs["I"]) == "grid":
169+
if kwargs.get("I") is not None and data_kind(kwargs["I"]) == "grid":
170170
shading_context = lib.virtualfile_from_grid(kwargs["I"])
171171
kwargs["I"] = stack.enter_context(shading_context)
172172

pygmt/src/grdview.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def grdview(self, grid, **kwargs):
126126
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
127127

128128
with contextlib.ExitStack() as stack:
129-
if "G" in kwargs: # deal with kwargs["G"] if drapegrid is xr.DataArray
129+
if kwargs.get("G") is not None:
130+
# deal with kwargs["G"] if drapegrid is xr.DataArray
130131
drapegrid = kwargs["G"]
131132
if data_kind(drapegrid) in ("file", "grid"):
132133
if data_kind(drapegrid) == "grid":

pygmt/src/makecpt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ def makecpt(**kwargs):
147147
``categorical=True``.
148148
"""
149149
with Session() as lib:
150-
if "W" in kwargs and "Ww" in kwargs:
150+
if kwargs.get("W") is not None and kwargs.get("Ww") is not None:
151151
raise GMTInvalidInput("Set only categorical or cyclic to True, not both.")
152-
if "H" not in kwargs: # if no output is set
152+
if kwargs.get("H") is None: # if no output is set
153153
arg_str = build_arg_string(kwargs)
154-
elif "H" in kwargs: # if output is set
154+
else: # if output is set
155155
outfile, kwargs["H"] = kwargs.pop("H"), True
156156
if not outfile or not isinstance(outfile, str):
157157
raise GMTInvalidInput("'output' should be a proper file name.")

pygmt/src/plot.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,15 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
218218
kind = data_kind(data, x, y)
219219

220220
extra_arrays = []
221-
if "S" in kwargs and kwargs["S"][0] in "vV" and direction is not None:
221+
if kwargs.get("S") is not None and kwargs["S"][0] in "vV" and direction is not None:
222222
extra_arrays.extend(direction)
223223
elif (
224-
"S" not in kwargs
224+
kwargs.get("S") is None
225225
and kind == "geojson"
226226
and data.geom_type.isin(["Point", "MultiPoint"]).all()
227227
): # checking if the geometry of a geoDataFrame is Point or MultiPoint
228228
kwargs["S"] = "s0.2c"
229-
elif "S" not in kwargs and kind == "file" and data.endswith(".gmt"):
229+
elif kwargs.get("S") is None and kind == "file" and data.endswith(".gmt"):
230230
# checking that the data is a file path to set default style
231231
try:
232232
with open(which(data), mode="r", encoding="utf8") as file:
@@ -236,7 +236,7 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
236236
kwargs["S"] = "s0.2c"
237237
except FileNotFoundError:
238238
pass
239-
if "G" in kwargs and is_nonstr_iter(kwargs["G"]):
239+
if kwargs.get("G") is not None and is_nonstr_iter(kwargs["G"]):
240240
if kind != "vectors":
241241
raise GMTInvalidInput(
242242
"Can't use arrays for color if data is matrix or file."
@@ -251,7 +251,7 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
251251
extra_arrays.append(size)
252252

253253
for flag in ["I", "t"]:
254-
if flag in kwargs and is_nonstr_iter(kwargs[flag]):
254+
if kwargs.get(flag) is not None and is_nonstr_iter(kwargs[flag]):
255255
if kind != "vectors":
256256
raise GMTInvalidInput(
257257
f"Can't use arrays for {plot.aliases[flag]} if data is matrix or file."

pygmt/src/plot3d.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,15 @@ def plot3d(
188188
kind = data_kind(data, x, y, z)
189189

190190
extra_arrays = []
191-
if "S" in kwargs and kwargs["S"][0] in "vV" and direction is not None:
191+
if kwargs.get("S") is not None and kwargs["S"][0] in "vV" and direction is not None:
192192
extra_arrays.extend(direction)
193193
elif (
194-
"S" not in kwargs
194+
kwargs.get("S") is None
195195
and kind == "geojson"
196196
and data.geom_type.isin(["Point", "MultiPoint"]).all()
197197
): # checking if the geometry of a geoDataFrame is Point or MultiPoint
198198
kwargs["S"] = "u0.2c"
199-
elif "S" not in kwargs and kind == "file" and data.endswith(".gmt"):
199+
elif kwargs.get("S") is None and kind == "file" and data.endswith(".gmt"):
200200
# checking that the data is a file path to set default style
201201
try:
202202
with open(which(data), mode="r", encoding="utf8") as file:
@@ -206,7 +206,7 @@ def plot3d(
206206
kwargs["S"] = "u0.2c"
207207
except FileNotFoundError:
208208
pass
209-
if "G" in kwargs and is_nonstr_iter(kwargs["G"]):
209+
if kwargs.get("G") is not None and is_nonstr_iter(kwargs["G"]):
210210
if kind != "vectors":
211211
raise GMTInvalidInput(
212212
"Can't use arrays for color if data is matrix or file."
@@ -221,7 +221,7 @@ def plot3d(
221221
extra_arrays.append(size)
222222

223223
for flag in ["I", "t"]:
224-
if flag in kwargs and is_nonstr_iter(kwargs[flag]):
224+
if kwargs.get(flag) is not None and is_nonstr_iter(kwargs[flag]):
225225
if kind != "vectors":
226226
raise GMTInvalidInput(
227227
f"Can't use arrays for {plot3d.aliases[flag]} if data is matrix or file."

pygmt/src/project.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,13 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
210210
by ``outfile``)
211211
"""
212212

213-
if "C" not in kwargs:
213+
if kwargs.get("C") is None:
214214
raise GMTInvalidInput("The `center` parameter must be specified.")
215-
if "G" not in kwargs and data is None:
215+
if kwargs.get("G") is None and data is None:
216216
raise GMTInvalidInput(
217217
"The `data` parameter must be specified unless `generate` is used."
218218
)
219-
if "G" in kwargs and "F" in kwargs:
219+
if kwargs.get("G") is not None and kwargs.get("F") is not None:
220220
raise GMTInvalidInput(
221221
"The `convention` parameter is not allowed with `generate`."
222222
)
@@ -225,7 +225,7 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
225225
if outfile is None: # Output to tmpfile if outfile is not set
226226
outfile = tmpfile.name
227227
with Session() as lib:
228-
if "G" not in kwargs:
228+
if kwargs.get("G") is None:
229229
# Choose how data will be passed into the module
230230
table_context = lib.virtualfile_from_data(
231231
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
@@ -240,7 +240,7 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
240240

241241
# if user did not set outfile, return pd.DataFrame
242242
if outfile == tmpfile.name:
243-
if "G" in kwargs:
243+
if kwargs.get("G") is not None:
244244
column_names = list("rsp")
245245
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
246246
else:

pygmt/src/solar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def solar(self, terminator="d", terminator_datetime=None, **kwargs):
6666
"""
6767

6868
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
69-
if "T" in kwargs:
69+
if kwargs.get("T") is not None:
7070
raise GMTInvalidInput(
7171
"Use 'terminator' and 'terminator_datetime' instead of 'T'."
7272
)

0 commit comments

Comments
 (0)