1111
1212import numpy as np
1313import matplotlib .pyplot as plt
14+ import matplotlib .colors as mcolors
1415
1516# Fixing random state for reproducibility
1617np .random .seed (19680801 )
@@ -26,15 +27,19 @@ def plot_scatter(ax, prng, nb_samples=100):
2627 return ax
2728
2829
29- def plot_colored_sinusoidal_lines (ax ):
30- """Plot sinusoidal lines with colors following the style color cycle."""
31- L = 2 * np .pi
32- x = np .linspace (0 , L )
30+ def plot_colored_lines (ax ):
31+ """Plot lines with colors following the style color cycle."""
32+ t = np .linspace (- 10 , 10 , 100 )
33+
34+ def sigmoid (t , t0 ):
35+ return 1 / (1 + np .exp (- (t - t0 )))
36+
3337 nb_colors = len (plt .rcParams ['axes.prop_cycle' ])
34- shift = np .linspace (0 , L , nb_colors , endpoint = False )
35- for s in shift :
36- ax .plot (x , np .sin (x + s ), '-' )
37- ax .set_xlim ([x [0 ], x [- 1 ]])
38+ shifts = np .linspace (- 5 , 5 , nb_colors )
39+ amplitudes = np .linspace (1 , 1.5 , nb_colors )
40+ for t0 , a in zip (shifts , amplitudes ):
41+ ax .plot (t , a * sigmoid (t , t0 ), '-' )
42+ ax .set_xlim (- 10 , 10 )
3843 return ax
3944
4045
@@ -108,23 +113,30 @@ def plot_figure(style_label=""):
108113 # double the width and halve the height. NB: use relative changes because
109114 # some styles may have a figure size different from the default one.
110115 (fig_width , fig_height ) = plt .rcParams ['figure.figsize' ]
111- fig_size = [fig_width * 2 , fig_height / 2 ]
116+ fig_size = [fig_width * 2 , fig_height / 1.75 ]
112117
113118 fig , axs = plt .subplots (ncols = 6 , nrows = 1 , num = style_label ,
114- figsize = fig_size , squeeze = True )
115- axs [0 ].set_ylabel (style_label , fontsize = 13 , fontweight = 'bold' )
119+ figsize = fig_size , constrained_layout = True )
120+
121+ # make a suptitle, in the same style for all subfigures,
122+ # except those with dark backgrounds, which get a lighter
123+ # color:
124+ col = np .array ([19 , 6 , 84 ])/ 256
125+ back = mcolors .rgb_to_hsv (
126+ mcolors .to_rgb (plt .rcParams ['figure.facecolor' ]))[2 ]
127+ if back < 0.5 :
128+ col = [0.8 , 0.8 , 1 ]
129+ fig .suptitle (style_label , x = 0.01 , fontsize = 14 , ha = 'left' ,
130+ color = col , fontfamily = 'DejaVu Sans' ,
131+ fontweight = 'normal' )
116132
117133 plot_scatter (axs [0 ], prng )
118134 plot_image_and_patch (axs [1 ], prng )
119135 plot_bar_graphs (axs [2 ], prng )
120136 plot_colored_circles (axs [3 ], prng )
121- plot_colored_sinusoidal_lines (axs [4 ])
137+ plot_colored_lines (axs [4 ])
122138 plot_histograms (axs [5 ], prng )
123139
124- fig .tight_layout ()
125-
126- return fig
127-
128140
129141if __name__ == "__main__" :
130142
@@ -141,6 +153,6 @@ def plot_figure(style_label=""):
141153 for style_label in style_list :
142154 with plt .rc_context ({"figure.max_open_warning" : len (style_list )}):
143155 with plt .style .context (style_label ):
144- fig = plot_figure (style_label = style_label )
156+ plot_figure (style_label = style_label )
145157
146158 plt .show ()
0 commit comments