77
88import keras .src .callbacks
99
10- from matplotlib .colors import Normalize
11- from ...utils .plot_utils import make_figure , add_titles_and_labels , gradient_line , gradient_legend
10+ from ...utils .plot_utils import make_figure , add_titles_and_labels , gradient_line
1211
1312
1413def loss (
1514 history : keras .callbacks .History ,
1615 train_key : str = "loss" ,
1716 val_key : str = "val_loss" ,
1817 moving_average : bool = True ,
19- per_training_step : bool = False ,
2018 moving_average_span : int = 10 ,
2119 figsize : Sequence [float ] = None ,
2220 train_color : str = "#132a70" ,
2321 val_color : str = None ,
24- val_colormap : str = ' viridis' ,
22+ val_colormap : str = " viridis" ,
2523 lw_train : float = 2.0 ,
2624 lw_val : float = 3.0 ,
2725 val_marker_type : str = "o" ,
@@ -45,22 +43,28 @@ def loss(
4543 The validation loss key to look for in the history
4644 moving_average : bool, optional, default: False
4745 A flag for adding an exponential moving average line of the train_losses.
48- per_training_step : bool, optional, default: False
49- A flag for making loss trajectory detailed (to training steps) rather than per epoch.
50- ma_window_fraction : int, optional, default: 0.01
46+ moving_average_span : int, optional, default: 0.01
5147 Window size for the moving average as a fraction of total
5248 training steps.
5349 figsize : tuple or None, optional, default: None
5450 The figure size passed to the ``matplotlib`` constructor.
5551 Inferred if ``None``
5652 train_color : str, optional, default: '#8f2727'
5753 The color for the train loss trajectory
58- val_color : str, optional, default: black
54+ val_color : str, optional, default: None
55+ The color for the optional validation loss trajectory
56+ val_colormap : str, optional, default: "viridis"
5957 The color for the optional validation loss trajectory
6058 lw_train : int, optional, default: 2
6159 The linewidth for the training loss curve
6260 lw_val : int, optional, default: 3
6361 The linewidth for the validation loss curve
62+ val_marker_type : str, optional, default: o
63+ The marker type for the validation loss curve
64+ val_marker_size : int, optional, default: 34
65+ The marker size for the validation loss curve
66+ grid_alpha : float, optional, default: 0.2
67+ The transparency of the background grid
6468 legend_fontsize : int, optional, default: 14
6569 The font size of the legend text
6670 label_fontsize : int, optional, default: 14
@@ -111,41 +115,32 @@ def loss(
111115
112116 # Plot optional val curve
113117 if val_losses is not None :
114- if val_color is not None :
115- ax .plot (
116- val_step_index ,
117- val_losses .iloc [:, 0 ],
118- linestyle = "--" ,
119- marker = val_marker_type ,
120- color = val_color ,
121- lw = lw_val ,
122- label = "Validation" ,
123- )
124- else :
125- # Create line segments between each epoch
126- points = np .array ([val_step_index , val_losses .iloc [:,0 ]]).T .reshape (- 1 , 1 , 2 )
127- segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
128-
129- # Normalize color based on loss values
130- lc = gradient_line (
131- val_step_index ,
132- val_losses .iloc [:,0 ],
133- c = val_step_index ,
134- cmap = val_colormap ,
135- lw = lw_val ,
136- ax = ax
137- )
138- scatter = ax .scatter (
139- val_step_index ,
140- val_losses .iloc [:,0 ],
141- c = val_step_index ,
142- cmap = val_colormap ,
143- marker = val_marker_type ,
144- s = val_marker_size ,
145- zorder = 10 ,
146- edgecolors = 'none' ,
147- label = 'Validation'
148- )
118+ if val_color is not None :
119+ ax .plot (
120+ val_step_index ,
121+ val_losses .iloc [:, 0 ],
122+ linestyle = "--" ,
123+ marker = val_marker_type ,
124+ color = val_color ,
125+ lw = lw_val ,
126+ label = "Validation" ,
127+ )
128+ else :
129+ # Make gradient lines
130+ gradient_line (
131+ val_step_index , val_losses .iloc [:, 0 ], c = val_step_index , cmap = val_colormap , lw = lw_val , ax = ax
132+ )
133+ ax .scatter (
134+ val_step_index ,
135+ val_losses .iloc [:, 0 ],
136+ c = val_step_index ,
137+ cmap = val_colormap ,
138+ marker = val_marker_type ,
139+ s = val_marker_size ,
140+ zorder = 10 ,
141+ edgecolors = "none" ,
142+ label = "Validation" ,
143+ )
149144
150145 sns .despine (ax = ax )
151146 ax .grid (alpha = grid_alpha )
@@ -160,7 +155,7 @@ def loss(
160155 num_row = num_row ,
161156 num_col = 1 ,
162157 title = ["Loss Trajectory" ],
163- xlabel = "Training step #" if per_training_step else "Training epoch #" ,
158+ xlabel = "Training epoch #" ,
164159 ylabel = "Value" ,
165160 title_fontsize = title_fontsize ,
166161 label_fontsize = label_fontsize ,
0 commit comments