@@ -15,11 +15,13 @@ def loss(
1515 train_key : str = "loss" ,
1616 val_key : str = "val_loss" ,
1717 per_training_step : bool = False ,
18+ smoothing_factor : float = 0.8 ,
1819 figsize : Sequence [float ] = None ,
1920 train_color : str = "#132a70" ,
2021 val_color : str = "black" ,
21- lw_train : float = 2.5 ,
22- lw_val : float = 2.5 ,
22+ lw_train : float = 2.0 ,
23+ lw_val : float = 2.0 ,
24+ grid_alpha : float = 0.2 ,
2325 legend_fontsize : int = 14 ,
2426 label_fontsize : int = 14 ,
2527 title_fontsize : int = 16 ,
@@ -38,17 +40,21 @@ def loss(
3840 The validation loss key to look for in the history
3941 per_training_step : bool, optional, default: False
4042 A flag for making loss trajectory detailed (to training steps) rather than per epoch.
43+ smoothing_factor : float, optional, default: 0.8
44+ If greater than zero, smooth the loss curves by applying an exponential moving average.
4145 figsize : tuple or None, optional, default: None
4246 The figure size passed to the ``matplotlib`` constructor.
4347 Inferred if ``None``
44- train_color : str, optional, default: '#8f2727 '
48+ train_color : str, optional, default: '#132a70 '
4549 The color for the train loss trajectory
46- val_color : str, optional, default: black
50+ val_color : str, optional, default: None
4751 The color for the optional validation loss trajectory
48- lw_train : int, optional, default: 1
52+ lw_train : int, optional, default: 2
4953 The linewidth for the training loss curve
5054 lw_val : int, optional, default: 2
5155 The linewidth for the validation loss curve
56+ grid_alpha : float, optional, default: 0.2
57+ The transparency of the background grid
5258 legend_fontsize : int, optional, default: 14
5359 The font size of the legend text
5460 label_fontsize : int, optional, default: 14
@@ -91,28 +97,60 @@ def loss(
9197
9298 # Loop through loss entries and populate plot
9399 for i , ax in enumerate (axes .flat ):
94- # Plot train curve
95- ax .plot (train_step_index , train_losses .iloc [:, i ], color = train_color , lw = lw_train , alpha = 0.9 , label = "Training" )
100+ if smoothing_factor > 0 :
101+ # plot unsmoothed train loss
102+ ax .plot (
103+ train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.3 , label = "Training"
104+ )
105+
106+ # plot smoothed train loss
107+ smoothed_train_loss = train_losses .iloc [:, 0 ].ewm (alpha = 1.0 - smoothing_factor , adjust = True ).mean ()
108+ ax .plot (
109+ train_step_index ,
110+ smoothed_train_loss ,
111+ color = train_color ,
112+ lw = lw_train ,
113+ alpha = 0.8 ,
114+ label = "Training (Moving Average)" ,
115+ )
116+ else :
117+ # plot unsmoothed train loss
118+ ax .plot (
119+ train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.8 , label = "Training"
120+ )
96121
97122 # Plot optional val curve
98123 if val_losses is not None :
99- if i < val_losses .shape [1 ]:
100- ax .plot (
101- val_step_index ,
102- val_losses .iloc [:, i ],
103- linestyle = "--" ,
104- marker = "o" ,
105- markersize = 5 ,
106- color = val_color ,
107- lw = lw_val ,
108- label = "Validation" ,
109- )
124+ if val_color is not None :
125+ if smoothing_factor > 0 :
126+ # plot unsmoothed val loss
127+ ax .plot (
128+ val_step_index , val_losses .iloc [:, 0 ], color = val_color , lw = lw_val , alpha = 0.3 , label = "Validation"
129+ )
130+
131+ # plot smoothed val loss
132+ smoothed_val_loss = val_losses .iloc [:, 0 ].ewm (alpha = 1.0 - smoothing_factor , adjust = True ).mean ()
133+ ax .plot (
134+ val_step_index ,
135+ smoothed_val_loss ,
136+ color = val_color ,
137+ lw = lw_val ,
138+ alpha = 0.8 ,
139+ label = "Validation (Moving Average)" ,
140+ )
141+ else :
142+ # plot unsmoothed val loss
143+ ax .plot (
144+ val_step_index , val_losses .iloc [:, 0 ], color = val_color , lw = lw_val , alpha = 0.8 , label = "Validation"
145+ )
110146
111147 sns .despine (ax = ax )
112- ax .grid (alpha = 0.5 )
148+ ax .grid (alpha = grid_alpha )
113149
114- # Only add legend if there is a validation curve
115- if val_losses is not None :
150+ ax .set_xlim (train_step_index [0 ], train_step_index [- 1 ])
151+
152+ # Only add the legend if there are multiple curves
153+ if val_losses is not None or smoothing_factor > 0 :
116154 ax .legend (fontsize = legend_fontsize )
117155
118156 # Add labels, titles, and set font sizes
0 commit comments