77
88import keras .src .callbacks
99
10- from ...utils .plot_utils import make_figure , add_titles_and_labels , add_gradient_plot
10+ from ...utils .plot_utils import make_figure , add_titles_and_labels
1111
1212
1313def loss (
@@ -17,13 +17,9 @@ def loss(
1717 smoothing_factor : float = 0.8 ,
1818 figsize : Sequence [float ] = None ,
1919 train_color : str = "#132a70" ,
20- val_color : str = None ,
21- val_colormap : str = "viridis" ,
20+ val_color : str = "black" ,
2221 lw_train : float = 2.0 ,
23- lw_val : float = 3.0 ,
24- marker : bool = True ,
25- val_marker_type : str = "." ,
26- val_marker_size : int = 34 ,
22+ lw_val : float = 2.0 ,
2723 grid_alpha : float = 0.2 ,
2824 legend_fontsize : int = 14 ,
2925 label_fontsize : int = 14 ,
@@ -41,29 +37,19 @@ def loss(
4137 The training loss key to look for in the history
4238 val_key : str, optional, default: "val_loss"
4339 The validation loss key to look for in the history
44- moving_average : bool, optional, default: False
45- A flag for adding an exponential moving average line of the train_losses.
46- moving_average_alpha : int, optional, default: 0.8
47- Smoothing factor for the moving average.
40+ smoothing_factor : float, optional, default: 0.8
41+ If greater than zero, smooth the loss curves by applying an exponential moving average.
4842 figsize : tuple or None, optional, default: None
4943 The figure size passed to the ``matplotlib`` constructor.
5044 Inferred if ``None``
5145 train_color : str, optional, default: '#8f2727'
5246 The color for the train loss trajectory
5347 val_color : str, optional, default: None
5448 The color for the optional validation loss trajectory
55- val_colormap : str, optional, default: "viridis"
56- The colormap for the optional validation loss trajectory
5749 lw_train : int, optional, default: 2
5850 The linewidth for the training loss curve
5951 lw_val : int, optional, default: 3
6052 The linewidth for the validation loss curve
61- marker : bool, optional, default: False
62- A flag for whether marker should be added in the validation loss trajectory
63- val_marker_type : str, optional, default: o
64- The marker type for the validation loss curve
65- val_marker_size : int, optional, default: 34
66- The marker size for the validation loss curve
6753 grid_alpha : float, optional, default: 0.2
6854 The transparency of the background grid
6955 legend_fontsize : int, optional, default: 14
@@ -108,68 +94,60 @@ def loss(
10894
10995 # Loop through loss entries and populate plot
11096 for i , ax in enumerate (axes .flat ):
111- # Plot train curve
112- ax .plot (train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.05 , label = "Training" )
113- if moving_average :
114- smoothed_train_loss = train_losses .iloc [:, 0 ].ewm (alpha = moving_average_alpha , adjust = True ).mean ()
115- ax .plot (train_step_index , smoothed_train_loss , color = "grey" , lw = lw_train , label = "Training (Moving Average)" )
97+ if smoothing_factor > 0 :
98+ # plot unsmoothed train loss
99+ ax .plot (
100+ train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.3 , label = "Training"
101+ )
102+
103+ # plot smoothed train loss
104+ smoothed_train_loss = train_losses .iloc [:, 0 ].ewm (alpha = 1.0 - smoothing_factor , adjust = True ).mean ()
105+ ax .plot (
106+ train_step_index ,
107+ smoothed_train_loss ,
108+ color = train_color ,
109+ lw = lw_train ,
110+ alpha = 0.8 ,
111+ label = "Training (Moving Average)" ,
112+ )
113+ else :
114+ # plot unsmoothed train loss
115+ ax .plot (
116+ train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.8 , label = "Training"
117+ )
116118
117119 # Plot optional val curve
118120 if val_losses is not None :
119121 if val_color is not None :
120- ax .plot (
121- val_step_index ,
122- val_losses .iloc [:, 0 ],
123- linestyle = "--" ,
124- marker = val_marker_type if marker else None ,
125- color = val_color ,
126- lw = lw_val ,
127- alpha = 0.2 ,
128- label = "Validation" ,
129- )
130- if moving_average :
131- smoothed_val_loss = val_losses .iloc [:, 0 ].ewm (alpha = moving_average_alpha , adjust = True ).mean ()
122+ if smoothing_factor > 0 :
123+ # plot unsmoothed val loss
124+ ax .plot (
125+ val_step_index , val_losses .iloc [:, 0 ], color = val_color , lw = lw_val , alpha = 0.3 , label = "Validation"
126+ )
127+
128+ # plot smoothed val loss
129+ smoothed_val_loss = val_losses .iloc [:, 0 ].ewm (alpha = 1.0 - smoothing_factor , adjust = True ).mean ()
132130 ax .plot (
133131 val_step_index ,
134132 smoothed_val_loss ,
135133 color = val_color ,
136134 lw = lw_val ,
135+ alpha = 0.8 ,
137136 label = "Validation (Moving Average)" ,
138137 )
139- else :
140- # Make gradient lines
141- add_gradient_plot (
142- val_step_index ,
143- val_losses .iloc [:, 0 ],
144- ax ,
145- val_colormap ,
146- lw_val ,
147- marker ,
148- val_marker_type ,
149- val_marker_size ,
150- alpha = 0.05 ,
151- label = "Validation" ,
152- )
153- if moving_average :
154- smoothed_val_loss = val_losses .iloc [:, 0 ].ewm (alpha = moving_average_alpha , adjust = True ).mean ()
155- add_gradient_plot (
156- val_step_index ,
157- smoothed_val_loss ,
158- ax ,
159- val_colormap ,
160- lw_val ,
161- marker ,
162- val_marker_type ,
163- val_marker_size ,
164- alpha = 1 ,
165- label = "Validation (Moving Average)" ,
138+ else :
139+ # plot unsmoothed val loss
140+ ax .plot (
141+ val_step_index , val_losses .iloc [:, 0 ], color = val_color , lw = lw_val , alpha = 0.8 , label = "Validation"
166142 )
167143
168144 sns .despine (ax = ax )
169145 ax .grid (alpha = grid_alpha )
170146
171- # Only add legend if there is a validation curve
172- if val_losses is not None or moving_average :
147+ ax .set_xlim (train_step_index [0 ], train_step_index [- 1 ])
148+
149+ # Only add the legend if there are multiple curves
150+ if val_losses is not None or smoothing_factor > 0 :
173151 ax .legend (fontsize = legend_fontsize )
174152
175153 # Add labels, titles, and set font sizes
0 commit comments