Skip to content

Commit 5bf4d18

Browse files
committed
Fix another mock problem.
1 parent a41e84d commit 5bf4d18

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

lib/iris/tests/unit/fileformats/netcdf/saver/test_Saver.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,6 @@ def test_compression(self):
261261
)
262262
cube.add_ancillary_variable(anc_coord, data_dims=data_dims)
263263

264-
patch = self.patch(
265-
"iris.fileformats.netcdf.saver._thread_safe_nc.DatasetWrapper.createVariable"
266-
)
267264
compression_kwargs = {
268265
"complevel": 9,
269266
"fletcher32": True,
@@ -273,10 +270,16 @@ def test_compression(self):
273270

274271
with self.temp_filename(suffix=".nc") as nc_path:
275272
with Saver(nc_path, "NETCDF4", compute=False) as saver:
273+
createvar_spy = self.patch(
274+
"iris.fileformats.netcdf.saver._thread_safe_nc.DatasetWrapper.createVariable",
275+
# Use 'wraps' to allow the patched methods to function as normal
276+
# - the patch object just acts as a 'spy' on its calls.
277+
wraps=saver._dataset.createVariable,
278+
)
276279
saver.write(cube, **compression_kwargs)
277280

278-
self.assertEqual(5, patch.call_count)
279-
result = self._filter_compression_calls(patch, compression_kwargs)
281+
self.assertEqual(5, createvar_spy.call_count)
282+
result = self._filter_compression_calls(createvar_spy, compression_kwargs)
280283
self.assertEqual(3, len(result))
281284
self.assertEqual({cube.name(), aux_coord.name(), anc_coord.name()}, set(result))
282285

@@ -294,9 +297,6 @@ def test_non_compression__shape(self):
294297
)
295298
cube.add_ancillary_variable(anc_coord, data_dims=data_dims[1])
296299

297-
patch = self.patch(
298-
"iris.fileformats.netcdf.saver._thread_safe_nc.DatasetWrapper.createVariable"
299-
)
300300
compression_kwargs = {
301301
"complevel": 9,
302302
"fletcher32": True,
@@ -306,11 +306,17 @@ def test_non_compression__shape(self):
306306

307307
with self.temp_filename(suffix=".nc") as nc_path:
308308
with Saver(nc_path, "NETCDF4", compute=False) as saver:
309+
createvar_spy = self.patch(
310+
"iris.fileformats.netcdf.saver._thread_safe_nc.DatasetWrapper.createVariable",
311+
# Use 'wraps' to allow the patched methods to function as normal
312+
# - the patch object just acts as a 'spy' on its calls.
313+
wraps=saver._dataset.createVariable,
314+
)
309315
saver.write(cube, **compression_kwargs)
310316

311-
self.assertEqual(5, patch.call_count)
317+
self.assertEqual(5, createvar_spy.call_count)
312318
result = self._filter_compression_calls(
313-
patch, compression_kwargs, mismatch=True
319+
createvar_spy, compression_kwargs, mismatch=True
314320
)
315321
self.assertEqual(4, len(result))
316322
# the aux coord and ancil variable are not compressed due to shape, and
@@ -327,10 +333,6 @@ def test_non_compression__dtype(self):
327333
aux_coord = AuxCoord(data, var_name="non_compress_aux", units="1")
328334
cube.add_aux_coord(aux_coord, data_dims=data_dims)
329335

330-
patch = self.patch(
331-
"iris.fileformats.netcdf.saver._thread_safe_nc.DatasetWrapper.createVariable"
332-
)
333-
patch.return_value = mock.MagicMock(dtype=np.dtype("S1"))
334336
compression_kwargs = {
335337
"complevel": 9,
336338
"fletcher32": True,
@@ -340,11 +342,17 @@ def test_non_compression__dtype(self):
340342

341343
with self.temp_filename(suffix=".nc") as nc_path:
342344
with Saver(nc_path, "NETCDF4", compute=False) as saver:
345+
createvar_spy = self.patch(
346+
"iris.fileformats.netcdf.saver._thread_safe_nc.DatasetWrapper.createVariable",
347+
# Use 'wraps' to allow the patched methods to function as normal
348+
# - the patch object just acts as a 'spy' on its calls.
349+
wraps=saver._dataset.createVariable,
350+
)
343351
saver.write(cube, **compression_kwargs)
344352

345-
self.assertEqual(4, patch.call_count)
353+
self.assertEqual(4, createvar_spy.call_count)
346354
result = self._filter_compression_calls(
347-
patch, compression_kwargs, mismatch=True
355+
createvar_spy, compression_kwargs, mismatch=True
348356
)
349357
self.assertEqual(3, len(result))
350358
# the aux coord is not compressed due to its string dtype, and

0 commit comments

Comments
 (0)