@@ -190,28 +190,33 @@ def findInterval(arr,weights,CL,mode='interval'):
190190ax [1 ].set_xlabel ("Chain index" )
191191ax [1 ].set_title (f"Trace plot of { kept_chain } chains / { j + 1 } chains" )
192192
193- # make a running average and 68% interval plot on top of the trace plot
193+ # make a sliding average and 68% interval plot on top of the trace plot
194194# this should be across the graphs and take ~5% of the average chain length as the window size
195195window_size = int (average_chain_length * 0.05 )
196- running_avg = np .empty (len (all_graphs [0 ])- window_size )
197- running_avg_upper = np .empty (len (all_graphs [0 ])- window_size )
198- running_avg_lower = np .empty (len (all_graphs [0 ])- window_size )
196+ num_windows = int (average_chain_length / window_size )
197+ running_avg = np .empty (num_windows )
198+ running_avg_upper = np .empty (num_windows )
199+ running_avg_lower = np .empty (num_windows )
200+ window_centers = np .empty (num_windows )
199201
200- for i in range (len ( all_graphs [ 0 ]) - window_size ):
202+ for i in range (num_windows ):
201203 window_vals = []
202204 window_weights = []
203205 for gr in all_graphs :
204- if i + window_size < len (gr ):
205- window_vals .append (gr [i : i + window_size ])
206- window_weights .append (np .ones (window_size )) # equal weights for the running average
206+ if ( i + 1 ) * window_size > len (gr ): continue
207+ window_vals .append (gr [i * window_size :( i + 1 ) * window_size ])
208+ window_weights .append (np .ones (window_size )) # equal weights for the running average
207209 window_vals = np .concatenate (window_vals )
208210 window_weights = np .concatenate (window_weights )
209211 running_avg [i ] = np .average (window_vals , weights = window_weights )
210212 interval = findInterval (window_vals , window_weights ,0.68 ,mode = 'interval' )
211213 running_avg_lower [i ] = interval [0 ]
212214 running_avg_upper [i ] = interval [1 ]
213- ax [1 ].plot (np .arange (len (running_avg )),running_avg , color = 'red' , marker = None , linestyle = '-' ,linewidth = 2 , label = "Running average" )
214- ax [1 ].fill_between (np .arange (len (running_avg )), running_avg_lower , running_avg_upper , color = 'red' , alpha = 0.5 , label = "68% interval" )
215+ window_center = (i * window_size )+ window_size / 2
216+ window_centers [i ] = window_center
217+
218+ ax [1 ].plot (window_centers ,running_avg , color = 'red' , marker = None , linestyle = '-' ,linewidth = 2 , label = "Sliding average" )
219+ ax [1 ].fill_between (window_centers , running_avg_lower , running_avg_upper , color = 'red' , alpha = 0.5 , label = "68% interval" )
215220ax [1 ].legend (loc = 'upper right' )
216221
217222if args .range :
0 commit comments