@@ -14,14 +14,13 @@ def loss(
1414 history : keras .callbacks .History ,
1515 train_key : str = "loss" ,
1616 val_key : str = "val_loss" ,
17- moving_average : bool = False ,
18- per_training_step : bool = False ,
19- ma_window_fraction : float = 0.01 ,
17+ smoothing_factor : float = 0.8 ,
2018 figsize : Sequence [float ] = None ,
2119 train_color : str = "#132a70" ,
2220 val_color : str = "black" ,
2321 lw_train : float = 2.0 ,
24- lw_val : float = 3.0 ,
22+ lw_val : float = 2.0 ,
23+ grid_alpha : float = 0.2 ,
2524 legend_fontsize : int = 14 ,
2625 label_fontsize : int = 14 ,
2726 title_fontsize : int = 16 ,
@@ -38,24 +37,21 @@ def loss(
3837 The training loss key to look for in the history
3938 val_key : str, optional, default: "val_loss"
4039 The validation loss key to look for in the history
41- moving_average : bool, optional, default: False
42- A flag for adding a moving average line of the train_losses.
43- per_training_step : bool, optional, default: False
44- A flag for making loss trajectory detailed (to training steps) rather than per epoch.
45- ma_window_fraction : int, optional, default: 0.01
46- Window size for the moving average as a fraction of total
47- training steps.
40+ smoothing_factor : float, optional, default: 0.8
41+ If greater than zero, smooth the loss curves by applying an exponential moving average.
4842 figsize : tuple or None, optional, default: None
4943 The figure size passed to the ``matplotlib`` constructor.
5044 Inferred if ``None``
5145 train_color : str, optional, default: '#8f2727'
5246 The color for the train loss trajectory
53- val_color : str, optional, default: black
47+ val_color : str, optional, default: None
5448 The color for the optional validation loss trajectory
5549 lw_train : int, optional, default: 2
5650 The linewidth for the training loss curve
5751 lw_val : int, optional, default: 3
5852 The linewidth for the validation loss curve
53+ grid_alpha : float, optional, default: 0.2
54+ The transparency of the background grid
5955 legend_fontsize : int, optional, default: 14
6056 The font size of the legend text
6157 label_fontsize : int, optional, default: 14
@@ -98,31 +94,60 @@ def loss(
9894
9995 # Loop through loss entries and populate plot
10096 for i , ax in enumerate (axes .flat ):
101- # Plot train curve
102- ax .plot (train_step_index , train_losses .iloc [:, i ], color = train_color , lw = lw_train , alpha = 0.9 , label = "Training" )
103- if moving_average and train_losses .columns [i ] == "Loss" :
104- moving_average_window = int (train_losses .shape [0 ] * ma_window_fraction )
105- smoothed_loss = train_losses .iloc [:, i ].rolling (window = moving_average_window ).mean ()
106- ax .plot (train_step_index , smoothed_loss , color = "grey" , lw = lw_train , label = "Training (Moving Average)" )
97+ if smoothing_factor > 0 :
98+ # plot unsmoothed train loss
99+ ax .plot (
100+ train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.3 , label = "Training"
101+ )
102+
103+ # plot smoothed train loss
104+ smoothed_train_loss = train_losses .iloc [:, 0 ].ewm (alpha = 1.0 - smoothing_factor , adjust = True ).mean ()
105+ ax .plot (
106+ train_step_index ,
107+ smoothed_train_loss ,
108+ color = train_color ,
109+ lw = lw_train ,
110+ alpha = 0.8 ,
111+ label = "Training (Moving Average)" ,
112+ )
113+ else :
114+ # plot unsmoothed train loss
115+ ax .plot (
116+ train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.8 , label = "Training"
117+ )
107118
108119 # Plot optional val curve
109120 if val_losses is not None :
110- if i < val_losses .shape [1 ]:
111- ax .plot (
112- val_step_index ,
113- val_losses .iloc [:, i ],
114- linestyle = "--" ,
115- marker = "o" ,
116- color = val_color ,
117- lw = lw_val ,
118- label = "Validation" ,
119- )
121+ if val_color is not None :
122+ if smoothing_factor > 0 :
123+ # plot unsmoothed val loss
124+ ax .plot (
125+ val_step_index , val_losses .iloc [:, 0 ], color = val_color , lw = lw_val , alpha = 0.3 , label = "Validation"
126+ )
127+
128+ # plot smoothed val loss
129+ smoothed_val_loss = val_losses .iloc [:, 0 ].ewm (alpha = 1.0 - smoothing_factor , adjust = True ).mean ()
130+ ax .plot (
131+ val_step_index ,
132+ smoothed_val_loss ,
133+ color = val_color ,
134+ lw = lw_val ,
135+ alpha = 0.8 ,
136+ label = "Validation (Moving Average)" ,
137+ )
138+ else :
139+ # plot unsmoothed val loss
140+ ax .plot (
141+ val_step_index , val_losses .iloc [:, 0 ], color = val_color , lw = lw_val , alpha = 0.8 , label = "Validation"
142+ )
120143
121144 sns .despine (ax = ax )
122- ax .grid (alpha = 0.5 )
145+ ax .grid (alpha = grid_alpha )
123146
124- # Only add legend if there is a validation curve
125- if val_losses is not None or moving_average :
147+ ax .set_xlim (train_step_index [0 ], train_step_index [- 1 ])
148+
149+ # Only add the legend if there are multiple curves
150+ if val_losses is not None or smoothing_factor > 0 :
126151 ax .legend (fontsize = legend_fontsize )
127152
128153 # Add labels, titles, and set font sizes
@@ -131,7 +156,7 @@ def loss(
131156 num_row = num_row ,
132157 num_col = 1 ,
133158 title = ["Loss Trajectory" ],
134- xlabel = "Training step #" if per_training_step else "Training epoch #" ,
159+ xlabel = "Training epoch #" ,
135160 ylabel = "Value" ,
136161 title_fontsize = title_fontsize ,
137162 label_fontsize = label_fontsize ,
0 commit comments