Skip to content

Commit 339d516

Browse files
authored
Merge pull request #110 from PKU-NIP-Lab/matplotlib_error
fix matplotlib dependency on "brainpy.analysis" module
2 parents 89dc9d6 + 07c2bec commit 339d516

File tree

4 files changed

+137
-106
lines changed

4 files changed

+137
-106
lines changed

brainpy/analysis/lowdim/lowdim_analyzer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from functools import partial
44

5-
import matplotlib.pyplot as plt
65
import numpy as np
76
from jax import numpy as jnp
87
from jax.scipy.optimize import minimize
@@ -12,6 +11,8 @@
1211
from brainpy.analysis import constants as C, utils
1312
from brainpy.base.collector import Collector
1413

14+
pyplot = None
15+
1516
__all__ = [
1617
'LowDimAnalyzer',
1718
'Num1DAnalyzer',
@@ -207,7 +208,10 @@ def __init__(self,
207208
self.analyzed_results = tools.DictPlus()
208209

209210
def show_figure(self):
210-
plt.show()
211+
global pyplot
212+
if pyplot is None:
213+
from matplotlib import pyplot
214+
pyplot.show()
211215

212216

213217
class Num1DAnalyzer(LowDimAnalyzer):

brainpy/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from functools import partial
44

55
import jax.numpy as jnp
6-
import matplotlib.pyplot as plt
76
import numpy as np
87

98
import brainpy.math as bm
109
from brainpy import errors
1110
from brainpy.analysis import stability, utils, constants as C
1211
from brainpy.analysis.lowdim.lowdim_analyzer import *
1312

13+
pyplot = None
14+
1415
__all__ = [
1516
'Bifurcation1D',
1617
'Bifurcation2D',
@@ -47,6 +48,8 @@ def F_vmap_dfxdx(self):
4748

4849
def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
4950
tol_aux=1e-8, loss_screen=None):
51+
global pyplot
52+
if pyplot is None: from matplotlib import pyplot
5053
utils.output('I am making bifurcation analysis ...')
5154

5255
xs = self.resolutions[self.x_var]
@@ -72,21 +75,21 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
7275
container[fp_type]['x'].append(x)
7376

7477
# visualization
75-
plt.figure(self.x_var)
78+
pyplot.figure(self.x_var)
7679
for fp_type, points in container.items():
7780
if len(points['x']):
7881
plot_style = stability.plot_scheme[fp_type]
79-
plt.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
80-
plt.xlabel(self.target_par_names[0])
81-
plt.ylabel(self.x_var)
82+
pyplot.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
83+
pyplot.xlabel(self.target_par_names[0])
84+
pyplot.ylabel(self.x_var)
8285

8386
scale = (self.lim_scale - 1) / 2
84-
plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
85-
plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
87+
pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
88+
pyplot.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
8689

87-
plt.legend()
90+
pyplot.legend()
8891
if show:
89-
plt.show()
92+
pyplot.show()
9093

9194
elif len(self.target_pars) == 2:
9295
container = {c: {'p0': [], 'p1': [], 'x': []} for c in stability.get_1d_stability_types()}
@@ -99,7 +102,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
99102
container[fp_type]['x'].append(x)
100103

101104
# visualization
102-
fig = plt.figure(self.x_var)
105+
fig = pyplot.figure(self.x_var)
103106
ax = fig.add_subplot(projection='3d')
104107
for fp_type, points in container.items():
105108
if len(points['x']):
@@ -121,7 +124,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
121124
ax.grid(True)
122125
ax.legend()
123126
if show:
124-
plt.show()
127+
pyplot.show()
125128

126129
else:
127130
raise errors.BrainPyError(f'Cannot visualize co-dimension {len(self.target_pars)} '
@@ -212,6 +215,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
212215
- parameters: a 2D matrix with the shape of (num_point, num_par)
213216
- jacobians: a 3D tensors with the shape of (num_point, 2, 2)
214217
"""
218+
global pyplot
219+
if pyplot is None: from matplotlib import pyplot
215220
utils.output('I am making bifurcation analysis ...')
216221

217222
if self._can_convert_to_one_eq():
@@ -289,21 +294,21 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
289294

290295
# visualization
291296
for var in self.target_var_names:
292-
plt.figure(var)
297+
pyplot.figure(var)
293298
for fp_type, points in container.items():
294299
if len(points['p']):
295300
plot_style = stability.plot_scheme[fp_type]
296-
plt.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
297-
plt.xlabel(self.target_par_names[0])
298-
plt.ylabel(var)
301+
pyplot.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
302+
pyplot.xlabel(self.target_par_names[0])
303+
pyplot.ylabel(var)
299304

300305
scale = (self.lim_scale - 1) / 2
301-
plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
302-
plt.ylim(*utils.rescale(self.target_vars[var], scale=scale))
306+
pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
307+
pyplot.ylim(*utils.rescale(self.target_vars[var], scale=scale))
303308

304-
plt.legend()
309+
pyplot.legend()
305310
if show:
306-
plt.show()
311+
pyplot.show()
307312

308313
# bifurcation analysis of co-dimension 2
309314
elif len(self.target_pars) == 2:
@@ -320,7 +325,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
320325

321326
# visualization
322327
for var in self.target_var_names:
323-
fig = plt.figure(var)
328+
fig = pyplot.figure(var)
324329
ax = fig.add_subplot(projection='3d')
325330
for fp_type, points in container.items():
326331
if len(points['p0']):
@@ -340,7 +345,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
340345
ax.grid(True)
341346
ax.legend()
342347
if show:
343-
plt.show()
348+
pyplot.show()
344349

345350
else:
346351
raise ValueError('Unknown length of parameters.')
@@ -350,6 +355,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
350355

351356
def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
352357
plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
358+
global pyplot
359+
if pyplot is None: from matplotlib import pyplot
353360
utils.output('I am plotting the limit cycle ...')
354361
if self._fixed_points is None:
355362
utils.output('No fixed points found, you may call "plot_bifurcation(with_plot=True)" first.')
@@ -390,27 +397,27 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=Fals
390397

391398
if len(self.target_par_names) == 2:
392399
for i, var in enumerate(self.target_var_names):
393-
plt.figure(var)
394-
plt.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
395-
**plot_style, label='limit cycle (max)')
396-
plt.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
397-
**plot_style, label='limit cycle (min)')
398-
plt.legend()
400+
pyplot.figure(var)
401+
pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
402+
**plot_style, label='limit cycle (max)')
403+
pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
404+
**plot_style, label='limit cycle (min)')
405+
pyplot.legend()
399406

400407
elif len(self.target_par_names) == 1:
401408
for i, var in enumerate(self.target_var_names):
402-
plt.figure(var)
403-
plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt,
404-
**plot_style, label='limit cycle (max)')
405-
plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt,
406-
**plot_style, label='limit cycle (min)')
407-
plt.legend()
409+
pyplot.figure(var)
410+
pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt,
411+
**plot_style, label='limit cycle (max)')
412+
pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt,
413+
**plot_style, label='limit cycle (min)')
414+
pyplot.legend()
408415

409416
else:
410417
raise errors.AnalyzerError
411418

412419
if show:
413-
plt.show()
420+
pyplot.show()
414421

415422
if with_return:
416423
return vs_limit_cycle, ps_limit_cycle
@@ -437,6 +444,8 @@ def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
437444

438445
def plot_trajectory(self, initials, duration, plot_durations=None,
439446
dt=None, show=False, with_plot=True, with_return=False):
447+
global pyplot
448+
if pyplot is None: from matplotlib import pyplot
440449
utils.output('I am plotting the trajectory ...')
441450

442451
# check the initial values
@@ -470,14 +479,14 @@ def plot_trajectory(self, initials, duration, plot_durations=None,
470479
end = int(plot_durations[i][1] / dt)
471480
p1_var = self.target_par_names[0]
472481
if len(self.target_par_names) == 1:
473-
lines = plt.plot(mon_res[self.x_var][start: end, i],
474-
mon_res[p1_var][start: end, i], label=legend)
482+
lines = pyplot.plot(mon_res[self.x_var][start: end, i],
483+
mon_res[p1_var][start: end, i], label=legend)
475484
elif len(self.target_par_names) == 2:
476485
p2_var = self.target_par_names[1]
477-
lines = plt.plot(mon_res[self.x_var][start: end, i],
478-
mon_res[p1_var][start: end, i],
479-
mon_res[p2_var][start: end, i],
480-
label=legend)
486+
lines = pyplot.plot(mon_res[self.x_var][start: end, i],
487+
mon_res[p1_var][start: end, i],
488+
mon_res[p2_var][start: end, i],
489+
label=legend)
481490
else:
482491
raise ValueError
483492
utils.add_arrow(lines[0])
@@ -488,10 +497,10 @@ def plot_trajectory(self, initials, duration, plot_durations=None,
488497
# scale = (self.lim_scale - 1.) / 2
489498
# plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
490499
# plt.ylim(*utils.rescale(self.target_vars[self.target_par_names[0]], scale=scale))
491-
plt.legend()
500+
pyplot.legend()
492501

493502
if show:
494-
plt.show()
503+
pyplot.show()
495504

496505
if with_return:
497506
return mon_res
@@ -517,6 +526,8 @@ def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
517526

518527
def plot_trajectory(self, initials, duration, plot_durations=None,
519528
dt=None, show=False, with_plot=True, with_return=False):
529+
global pyplot
530+
if pyplot is None: from matplotlib import pyplot
520531
utils.output('I am plotting the trajectory ...')
521532

522533
# check the initial values
@@ -548,25 +559,25 @@ def plot_trajectory(self, initials, duration, plot_durations=None,
548559
end = int(plot_durations[i][1] / dt)
549560

550561
# visualization
551-
plt.figure(self.x_var)
552-
lines = plt.plot(mon_res[self.target_par_names[0]][start: end, i],
553-
mon_res[self.x_var][start: end, i],
554-
label=legend)
562+
pyplot.figure(self.x_var)
563+
lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i],
564+
mon_res[self.x_var][start: end, i],
565+
label=legend)
555566
utils.add_arrow(lines[0])
556567

557-
plt.figure(self.y_var)
558-
lines = plt.plot(mon_res[self.target_par_names[0]][start: end, i],
559-
mon_res[self.y_var][start: end, i],
560-
label=legend)
568+
pyplot.figure(self.y_var)
569+
lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i],
570+
mon_res[self.y_var][start: end, i],
571+
label=legend)
561572
utils.add_arrow(lines[0])
562573

563-
plt.figure(self.x_var)
564-
plt.legend()
565-
plt.figure(self.y_var)
566-
plt.legend()
574+
pyplot.figure(self.x_var)
575+
pyplot.legend()
576+
pyplot.figure(self.y_var)
577+
pyplot.legend()
567578

568579
if show:
569-
plt.show()
580+
pyplot.show()
570581

571582
if with_return:
572583
return mon_res

0 commit comments

Comments
 (0)