33from functools import partial
44
55import jax .numpy as jnp
6- import matplotlib .pyplot as plt
76import numpy as np
87
98import brainpy .math as bm
109from brainpy import errors
1110from brainpy .analysis import stability , utils , constants as C
1211from brainpy .analysis .lowdim .lowdim_analyzer import *
1312
13+ pyplot = None
14+
1415__all__ = [
1516 'Bifurcation1D' ,
1617 'Bifurcation2D' ,
@@ -47,6 +48,8 @@ def F_vmap_dfxdx(self):
4748
4849 def plot_bifurcation (self , with_plot = True , show = False , with_return = False ,
4950 tol_aux = 1e-8 , loss_screen = None ):
51+ global pyplot
52+ if pyplot is None : from matplotlib import pyplot
5053 utils .output ('I am making bifurcation analysis ...' )
5154
5255 xs = self .resolutions [self .x_var ]
@@ -72,21 +75,21 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
7275 container [fp_type ]['x' ].append (x )
7376
7477 # visualization
75- plt .figure (self .x_var )
78+ pyplot .figure (self .x_var )
7679 for fp_type , points in container .items ():
7780 if len (points ['x' ]):
7881 plot_style = stability .plot_scheme [fp_type ]
79- plt .plot (points ['p' ], points ['x' ], '.' , ** plot_style , label = fp_type )
80- plt .xlabel (self .target_par_names [0 ])
81- plt .ylabel (self .x_var )
82+ pyplot .plot (points ['p' ], points ['x' ], '.' , ** plot_style , label = fp_type )
83+ pyplot .xlabel (self .target_par_names [0 ])
84+ pyplot .ylabel (self .x_var )
8285
8386 scale = (self .lim_scale - 1 ) / 2
84- plt .xlim (* utils .rescale (self .target_pars [self .target_par_names [0 ]], scale = scale ))
85- plt .ylim (* utils .rescale (self .target_vars [self .x_var ], scale = scale ))
87+ pyplot .xlim (* utils .rescale (self .target_pars [self .target_par_names [0 ]], scale = scale ))
88+ pyplot .ylim (* utils .rescale (self .target_vars [self .x_var ], scale = scale ))
8689
87- plt .legend ()
90+ pyplot .legend ()
8891 if show :
89- plt .show ()
92+ pyplot .show ()
9093
9194 elif len (self .target_pars ) == 2 :
9295 container = {c : {'p0' : [], 'p1' : [], 'x' : []} for c in stability .get_1d_stability_types ()}
@@ -99,7 +102,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
99102 container [fp_type ]['x' ].append (x )
100103
101104 # visualization
102- fig = plt .figure (self .x_var )
105+ fig = pyplot .figure (self .x_var )
103106 ax = fig .add_subplot (projection = '3d' )
104107 for fp_type , points in container .items ():
105108 if len (points ['x' ]):
@@ -121,7 +124,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
121124 ax .grid (True )
122125 ax .legend ()
123126 if show :
124- plt .show ()
127+ pyplot .show ()
125128
126129 else :
127130 raise errors .BrainPyError (f'Cannot visualize co-dimension { len (self .target_pars )} '
@@ -212,6 +215,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
212215 - parameters: a 2D matrix with the shape of (num_point, num_par)
213216 - jacobians: a 3D tensors with the shape of (num_point, 2, 2)
214217 """
218+ global pyplot
219+ if pyplot is None : from matplotlib import pyplot
215220 utils .output ('I am making bifurcation analysis ...' )
216221
217222 if self ._can_convert_to_one_eq ():
@@ -289,21 +294,21 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
289294
290295 # visualization
291296 for var in self .target_var_names :
292- plt .figure (var )
297+ pyplot .figure (var )
293298 for fp_type , points in container .items ():
294299 if len (points ['p' ]):
295300 plot_style = stability .plot_scheme [fp_type ]
296- plt .plot (points ['p' ], points [var ], '.' , ** plot_style , label = fp_type )
297- plt .xlabel (self .target_par_names [0 ])
298- plt .ylabel (var )
301+ pyplot .plot (points ['p' ], points [var ], '.' , ** plot_style , label = fp_type )
302+ pyplot .xlabel (self .target_par_names [0 ])
303+ pyplot .ylabel (var )
299304
300305 scale = (self .lim_scale - 1 ) / 2
301- plt .xlim (* utils .rescale (self .target_pars [self .target_par_names [0 ]], scale = scale ))
302- plt .ylim (* utils .rescale (self .target_vars [var ], scale = scale ))
306+ pyplot .xlim (* utils .rescale (self .target_pars [self .target_par_names [0 ]], scale = scale ))
307+ pyplot .ylim (* utils .rescale (self .target_vars [var ], scale = scale ))
303308
304- plt .legend ()
309+ pyplot .legend ()
305310 if show :
306- plt .show ()
311+ pyplot .show ()
307312
308313 # bifurcation analysis of co-dimension 2
309314 elif len (self .target_pars ) == 2 :
@@ -320,7 +325,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
320325
321326 # visualization
322327 for var in self .target_var_names :
323- fig = plt .figure (var )
328+ fig = pyplot .figure (var )
324329 ax = fig .add_subplot (projection = '3d' )
325330 for fp_type , points in container .items ():
326331 if len (points ['p0' ]):
@@ -340,7 +345,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
340345 ax .grid (True )
341346 ax .legend ()
342347 if show :
343- plt .show ()
348+ pyplot .show ()
344349
345350 else :
346351 raise ValueError ('Unknown length of parameters.' )
@@ -350,6 +355,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
350355
351356 def plot_limit_cycle_by_sim (self , duration = 100 , with_plot = True , with_return = False ,
352357 plot_style = None , tol = 0.001 , show = False , dt = None , offset = 1. ):
358+ global pyplot
359+ if pyplot is None : from matplotlib import pyplot
353360 utils .output ('I am plotting the limit cycle ...' )
354361 if self ._fixed_points is None :
355362 utils .output ('No fixed points found, you may call "plot_bifurcation(with_plot=True)" first.' )
@@ -390,27 +397,27 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=Fals
390397
391398 if len (self .target_par_names ) == 2 :
392399 for i , var in enumerate (self .target_var_names ):
393- plt .figure (var )
394- plt .plot (ps_limit_cycle [0 ], ps_limit_cycle [1 ], vs_limit_cycle [i ]['max' ],
395- ** plot_style , label = 'limit cycle (max)' )
396- plt .plot (ps_limit_cycle [0 ], ps_limit_cycle [1 ], vs_limit_cycle [i ]['min' ],
397- ** plot_style , label = 'limit cycle (min)' )
398- plt .legend ()
400+ pyplot .figure (var )
401+ pyplot .plot (ps_limit_cycle [0 ], ps_limit_cycle [1 ], vs_limit_cycle [i ]['max' ],
402+ ** plot_style , label = 'limit cycle (max)' )
403+ pyplot .plot (ps_limit_cycle [0 ], ps_limit_cycle [1 ], vs_limit_cycle [i ]['min' ],
404+ ** plot_style , label = 'limit cycle (min)' )
405+ pyplot .legend ()
399406
400407 elif len (self .target_par_names ) == 1 :
401408 for i , var in enumerate (self .target_var_names ):
402- plt .figure (var )
403- plt .plot (ps_limit_cycle [0 ], vs_limit_cycle [i ]['max' ], fmt ,
404- ** plot_style , label = 'limit cycle (max)' )
405- plt .plot (ps_limit_cycle [0 ], vs_limit_cycle [i ]['min' ], fmt ,
406- ** plot_style , label = 'limit cycle (min)' )
407- plt .legend ()
409+ pyplot .figure (var )
410+ pyplot .plot (ps_limit_cycle [0 ], vs_limit_cycle [i ]['max' ], fmt ,
411+ ** plot_style , label = 'limit cycle (max)' )
412+ pyplot .plot (ps_limit_cycle [0 ], vs_limit_cycle [i ]['min' ], fmt ,
413+ ** plot_style , label = 'limit cycle (min)' )
414+ pyplot .legend ()
408415
409416 else :
410417 raise errors .AnalyzerError
411418
412419 if show :
413- plt .show ()
420+ pyplot .show ()
414421
415422 if with_return :
416423 return vs_limit_cycle , ps_limit_cycle
@@ -437,6 +444,8 @@ def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
437444
438445 def plot_trajectory (self , initials , duration , plot_durations = None ,
439446 dt = None , show = False , with_plot = True , with_return = False ):
447+ global pyplot
448+ if pyplot is None : from matplotlib import pyplot
440449 utils .output ('I am plotting the trajectory ...' )
441450
442451 # check the initial values
@@ -470,14 +479,14 @@ def plot_trajectory(self, initials, duration, plot_durations=None,
470479 end = int (plot_durations [i ][1 ] / dt )
471480 p1_var = self .target_par_names [0 ]
472481 if len (self .target_par_names ) == 1 :
473- lines = plt .plot (mon_res [self .x_var ][start : end , i ],
474- mon_res [p1_var ][start : end , i ], label = legend )
482+ lines = pyplot .plot (mon_res [self .x_var ][start : end , i ],
483+ mon_res [p1_var ][start : end , i ], label = legend )
475484 elif len (self .target_par_names ) == 2 :
476485 p2_var = self .target_par_names [1 ]
477- lines = plt .plot (mon_res [self .x_var ][start : end , i ],
478- mon_res [p1_var ][start : end , i ],
479- mon_res [p2_var ][start : end , i ],
480- label = legend )
486+ lines = pyplot .plot (mon_res [self .x_var ][start : end , i ],
487+ mon_res [p1_var ][start : end , i ],
488+ mon_res [p2_var ][start : end , i ],
489+ label = legend )
481490 else :
482491 raise ValueError
483492 utils .add_arrow (lines [0 ])
@@ -488,10 +497,10 @@ def plot_trajectory(self, initials, duration, plot_durations=None,
488497 # scale = (self.lim_scale - 1.) / 2
489498 # plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
490499 # plt.ylim(*utils.rescale(self.target_vars[self.target_par_names[0]], scale=scale))
491- plt .legend ()
500+ pyplot .legend ()
492501
493502 if show :
494- plt .show ()
503+ pyplot .show ()
495504
496505 if with_return :
497506 return mon_res
@@ -517,6 +526,8 @@ def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
517526
518527 def plot_trajectory (self , initials , duration , plot_durations = None ,
519528 dt = None , show = False , with_plot = True , with_return = False ):
529+ global pyplot
530+ if pyplot is None : from matplotlib import pyplot
520531 utils .output ('I am plotting the trajectory ...' )
521532
522533 # check the initial values
@@ -548,25 +559,25 @@ def plot_trajectory(self, initials, duration, plot_durations=None,
548559 end = int (plot_durations [i ][1 ] / dt )
549560
550561 # visualization
551- plt .figure (self .x_var )
552- lines = plt .plot (mon_res [self .target_par_names [0 ]][start : end , i ],
553- mon_res [self .x_var ][start : end , i ],
554- label = legend )
562+ pyplot .figure (self .x_var )
563+ lines = pyplot .plot (mon_res [self .target_par_names [0 ]][start : end , i ],
564+ mon_res [self .x_var ][start : end , i ],
565+ label = legend )
555566 utils .add_arrow (lines [0 ])
556567
557- plt .figure (self .y_var )
558- lines = plt .plot (mon_res [self .target_par_names [0 ]][start : end , i ],
559- mon_res [self .y_var ][start : end , i ],
560- label = legend )
568+ pyplot .figure (self .y_var )
569+ lines = pyplot .plot (mon_res [self .target_par_names [0 ]][start : end , i ],
570+ mon_res [self .y_var ][start : end , i ],
571+ label = legend )
561572 utils .add_arrow (lines [0 ])
562573
563- plt .figure (self .x_var )
564- plt .legend ()
565- plt .figure (self .y_var )
566- plt .legend ()
574+ pyplot .figure (self .x_var )
575+ pyplot .legend ()
576+ pyplot .figure (self .y_var )
577+ pyplot .legend ()
567578
568579 if show :
569- plt .show ()
580+ pyplot .show ()
570581
571582 if with_return :
572583 return mon_res
0 commit comments