@@ -114,42 +114,43 @@ def loss(
114114 label = "Training (Moving Average)" ,
115115 )
116116 else :
117- # plot unsmoothed train loss
117+ # Plot unsmoothed train loss
118118 ax .plot (
119119 train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.8 , label = "Training"
120120 )
121121
122- # Plot optional val curve
123- if val_losses is not None :
124- if val_color is not None :
125- if smoothing_factor > 0 :
126- # plot unsmoothed val loss
127- ax .plot (
128- val_step_index , val_losses .iloc [:, 0 ], color = val_color , lw = lw_val , alpha = 0.3 , label = "Validation"
129- )
130-
131- # plot smoothed val loss
132- smoothed_val_loss = val_losses .iloc [:, 0 ].ewm (alpha = 1.0 - smoothing_factor , adjust = True ).mean ()
133- ax .plot (
134- val_step_index ,
135- smoothed_val_loss ,
136- color = val_color ,
137- lw = lw_val ,
138- alpha = 0.8 ,
139- label = "Validation (Moving Average)" ,
140- )
141- else :
142- # plot unsmoothed val loss
143- ax .plot (
144- val_step_index , val_losses .iloc [:, 0 ], color = val_color , lw = lw_val , alpha = 0.8 , label = "Validation"
145- )
122+ # Only plot if we actually have validation losses and a color assigned
123+ if val_losses is not None and val_color is not None :
124+ alpha_unsmoothed = 0.3 if smoothing_factor > 0 else 0.8
146125
126+ # Plot unsmoothed val loss
127+ ax .plot (
128+ val_step_index ,
129+ val_losses .iloc [:, 0 ],
130+ color = val_color ,
131+ lw = lw_val ,
132+ alpha = alpha_unsmoothed ,
133+ label = "Validation" ,
134+ )
135+
136+ # if requested, plot a second, smoothed curve
137+ if smoothing_factor > 0 :
138+ smoothed_val_loss = val_losses .iloc [:, 0 ].ewm (alpha = 1.0 - smoothing_factor , adjust = True ).mean ()
139+ ax .plot (
140+ val_step_index ,
141+ smoothed_val_loss ,
142+ color = val_color ,
143+ lw = lw_val ,
144+ alpha = 0.8 ,
145+ label = "Validation (Moving Average)" ,
146+ )
147+
148+ # rest of the styling
147149 sns .despine (ax = ax )
148150 ax .grid (alpha = grid_alpha )
149-
150151 ax .set_xlim (train_step_index [0 ], train_step_index [- 1 ])
151152
152- # Only add the legend if there are multiple curves
153+ # legend only if there's at least one validation curve or smoothing was on
153154 if val_losses is not None or smoothing_factor > 0 :
154155 ax .legend (fontsize = legend_fontsize )
155156
0 commit comments