1414import matplotlib .style
1515import matplotlib .units
1616import matplotlib .testing
17- from matplotlib import cbook
18- from matplotlib import ft2font
19- from matplotlib import pyplot as plt
20- from matplotlib import ticker
17+ from matplotlib import cbook , ft2font , pyplot as plt , ticker , _pylab_helpers
2118from .compare import comparable_formats , compare_images , make_test_filename
2219from .exceptions import ImageComparisonFailure
2320
@@ -129,6 +126,29 @@ def remove_ticks(ax):
129126 remove_ticks (ax )
130127
131128
129+ @contextlib .contextmanager
130+ def _collect_new_figures ():
131+ """
132+ After::
133+
134+ with _collect_new_figures() as figs:
135+ some_code()
136+
137+ the list *figs* contains the figures that have been created during the
138+ execution of ``some_code``, sorted by figure number.
139+ """
140+ managers = _pylab_helpers .Gcf .figs
141+ preexisting = [manager for manager in managers .values ()]
142+ new_figs = []
143+ try :
144+ yield new_figs
145+ finally :
146+ new_managers = sorted ([manager for manager in managers .values ()
147+ if manager not in preexisting ],
148+ key = lambda manager : manager .num )
149+ new_figs [:] = [manager .canvas .figure for manager in new_managers ]
150+
151+
132152def _raise_on_image_difference (expected , actual , tol ):
133153 __tracebackhide__ = True
134154
@@ -178,10 +198,8 @@ def copy_baseline(self, baseline, extension):
178198 f"{ orig_expected_path } " ) from err
179199 return expected_fname
180200
181- def compare (self , idx , baseline , extension , * , _lock = False ):
201+ def compare (self , fig , baseline , extension , * , _lock = False ):
182202 __tracebackhide__ = True
183- fignum = plt .get_fignums ()[idx ]
184- fig = plt .figure (fignum )
185203
186204 if self .remove_text :
187205 remove_ticks_and_titles (fig )
@@ -196,7 +214,12 @@ def compare(self, idx, baseline, extension, *, _lock=False):
196214 lock = (cbook ._lock_path (actual_path )
197215 if _lock else contextlib .nullcontext ())
198216 with lock :
199- fig .savefig (actual_path , ** kwargs )
217+ try :
218+ fig .savefig (actual_path , ** kwargs )
219+ finally :
220+ # Matplotlib has an autouse fixture to close figures, but this
221+ # makes things more convenient for third-party users.
222+ plt .close (fig )
200223 expected_path = self .copy_baseline (baseline , extension )
201224 _raise_on_image_difference (expected_path , actual_path , self .tol )
202225
@@ -235,7 +258,9 @@ def wrapper(*args, extension, request, **kwargs):
235258 img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
236259 savefig_kwargs = savefig_kwargs )
237260 matplotlib .testing .set_font_settings_for_testing ()
238- func (* args , ** kwargs )
261+
262+ with _collect_new_figures () as figs :
263+ func (* args , ** kwargs )
239264
240265 # If the test is parametrized in any way other than applied via
241266 # this decorator, then we need to use a lock to prevent two
@@ -252,11 +277,11 @@ def wrapper(*args, extension, request, **kwargs):
252277 our_baseline_images = request .getfixturevalue (
253278 'baseline_images' )
254279
255- assert len (plt . get_fignums () ) == len (our_baseline_images ), (
280+ assert len (figs ) == len (our_baseline_images ), (
256281 "Test generated {} images but there are {} baseline images"
257- .format (len (plt . get_fignums () ), len (our_baseline_images )))
258- for idx , baseline in enumerate ( our_baseline_images ):
259- img .compare (idx , baseline , extension , _lock = needs_lock )
282+ .format (len (figs ), len (our_baseline_images )))
283+ for fig , baseline in zip ( figs , our_baseline_images ):
284+ img .compare (fig , baseline , extension , _lock = needs_lock )
260285
261286 parameters = list (old_sig .parameters .values ())
262287 if 'extension' not in old_sig .parameters :
@@ -427,11 +452,9 @@ def wrapper(*args, ext, request, **kwargs):
427452 try :
428453 fig_test = plt .figure ("test" )
429454 fig_ref = plt .figure ("reference" )
430- # Keep track of number of open figures, to make sure test
431- # doesn't create any new ones
432- n_figs = len (plt .get_fignums ())
433- func (* args , fig_test = fig_test , fig_ref = fig_ref , ** kwargs )
434- if len (plt .get_fignums ()) > n_figs :
455+ with _collect_new_figures () as figs :
456+ func (* args , fig_test = fig_test , fig_ref = fig_ref , ** kwargs )
457+ if figs :
435458 raise RuntimeError ('Number of open figures changed during '
436459 'test. Make sure you are plotting to '
437460 'fig_test or fig_ref, or if this is '
0 commit comments