Skip to content

Commit a5b6907

Browse files
Aliya Nigamovaanigamova
authored andcommitted
lint
1 parent 93d4db4 commit a5b6907

File tree

1 file changed

+66
-60
lines changed

1 file changed

+66
-60
lines changed

scripts/debugChains.py

Lines changed: 66 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#!/usr/bin/env python3
22
import matplotlib.pyplot as plt
3-
#import HiggsAnalysis.CombinedLimit.util.plotting as plot
3+
4+
# import HiggsAnalysis.CombinedLimit.util.plotting as plot
45
import argparse, sys
56
import numpy as np
67
import ROOT
78

89
# recent versions of numpy complain but can ignore
910
import warnings
11+
1012
warnings.filterwarnings("ignore", message="The value of the smallest subnormal for")
1113

1214
ROOT.PyConfig.IgnoreCommandLineOptions = True
@@ -45,7 +47,7 @@
4547
type=float,
4648
default=0.95,
4749
help="""Confidence level to use for the interval""",
48-
)
50+
)
4951
parser.add_argument(
5052
"--burnInFraction",
5153
"-b",
@@ -65,51 +67,52 @@
6567
help="""Choose from interval (default)/upperlim/lowerlim""",
6668
)
6769

68-
modes = ["upperlim","lowerlim","interval"]
70+
modes = ["upperlim", "lowerlim", "interval"]
6971

7072
args = parser.parse_args()
7173

72-
if args.mode not in modes:
73-
print(f'ERROR, for --mode, must pick from',modes)
74-
sys.exit(0)
74+
if args.mode not in modes:
75+
print(f"ERROR, for --mode, must pick from", modes)
76+
sys.exit(0)
7577

7678

77-
def weighted_percentile(data,weights,perc):
79+
def weighted_percentile(data, weights, perc):
7880
# Source - https://stackoverflow.com/a/61343915
7981
# Posted by imbr, modified by community. See post 'Timeline' for change history
8082
# Retrieved 2026-03-03, License - CC BY-SA 4.0
8183
# based off https://en.wikipedia.org/wiki/Percentile#Definition_of_the_Weighted_Percentile_method
8284

8385
ix = np.argsort(data)
84-
data = data[ix] # sort data
85-
weights = weights[ix] # sort weights
86-
cdf = (np.cumsum(weights) - 0.5 * weights) / np.sum(weights) # 'like' a CDF function
86+
data = data[ix] # sort data
87+
weights = weights[ix] # sort weights
88+
cdf = (np.cumsum(weights) - 0.5 * weights) / np.sum(weights) # 'like' a CDF function
8789
return np.interp(perc, cdf, data)
8890

8991

90-
def findInterval(arr,weights,CL,mode='interval'):
92+
def findInterval(arr, weights, CL, mode="interval"):
9193
# have an array of values and a CL
92-
# find the interval that contains CL of them
94+
# find the interval that contains CL of them
95+
96+
if mode == "interval":
97+
left_q = (1 - CL) / 2
98+
right_q = 1 - left_q
9399

94-
if mode == 'interval':
95-
left_q = (1-CL)/2
96-
right_q = 1-left_q
97-
98-
elif mode == 'upperlim':
100+
elif mode == "upperlim":
99101
left_q = 0
100102
right_q = CL
101-
102-
elif mode == 'lowerlim':
103-
left_q = 1-CL
103+
104+
elif mode == "lowerlim":
105+
left_q = 1 - CL
104106
right_q = 1
105107

106-
q1 = weighted_percentile(arr,weights,left_q)
107-
q2 = weighted_percentile(arr,weights,right_q)
108-
return [q1,q2]
108+
q1 = weighted_percentile(arr, weights, left_q)
109+
q2 = weighted_percentile(arr, weights, right_q)
110+
return [q1, q2]
111+
109112

110113
fi_MCMC = ROOT.TFile.Open(args.input)
111114

112-
# array to hold the values of the parameter in the chain
115+
# array to hold the values of the parameter in the chain
113116
param_value_chunks = []
114117
param_weight_chunks = []
115118

@@ -124,11 +127,11 @@ def findInterval(arr,weights,CL,mode='interval'):
124127
mychain = k.ReadObj().GetAsDataSet()
125128
lenchain = mychain.numEntries()
126129
average_chain_length += lenchain
127-
burnin = int(lenchain*args.burnInFraction)
130+
burnin = int(lenchain * args.burnInFraction)
128131
start_idx = burnin + 1
129132
total_chains += 1
130133

131-
keep_gr = np.random.uniform(0,1)<args.chainsF or j==0
134+
keep_gr = np.random.uniform(0, 1) < args.chainsF or j == 0
132135
chain_vals = np.empty(lenchain, dtype=float)
133136
chain_weights = np.empty(lenchain, dtype=float)
134137

@@ -144,7 +147,7 @@ def findInterval(arr,weights,CL,mode='interval'):
144147
graphs.append(keep_gr)
145148
if keep_gr:
146149
kept_chain += 1
147-
150+
148151
all_graphs.append(chain_vals)
149152

150153
if param_value_chunks:
@@ -154,75 +157,78 @@ def findInterval(arr,weights,CL,mode='interval'):
154157
param_values = np.array([])
155158
param_weights = np.array([])
156159

157-
average_chain_length = float(average_chain_length)/(total_chains)
160+
average_chain_length = float(average_chain_length) / (total_chains)
158161

159-
plt.rcParams.update({'font.size': 12})
160-
fig,ax = plt.subplots(1,2,figsize=(14,5))
162+
plt.rcParams.update({"font.size": 12})
163+
fig, ax = plt.subplots(1, 2, figsize=(14, 5))
161164

162-
param_values = np.array(param_values,dtype=float)
163-
ax[0].hist(param_values, density=True, color='black', bins=args.nbins, range=args.range, weights=param_weights, histtype='step')
165+
param_values = np.array(param_values, dtype=float)
166+
ax[0].hist(param_values, density=True, color="black", bins=args.nbins, range=args.range, weights=param_weights, histtype="step")
164167
ax[0].set_xlabel(args.param)
165168
ax[0].set_ylabel("Posterior probability density")
166169

167-
interval = findInterval(param_values, param_weights,args.CL,args.mode)
170+
interval = findInterval(param_values, param_weights, args.CL, args.mode)
168171
print(f"Average chain length: {average_chain_length:.1f}")
169172
print(f"Number of chains: {j+1}")
170173
print(f"Burn-in fraction: {args.burnInFraction:.2f} (average burn-in length: {average_chain_length*args.burnInFraction:.1f} entries)")
171174

172-
if args.mode=="interval":
173-
label =f"{args.CL*100:.1f}% CL interval: {interval[0]:.3f} < {args.param} < {interval[1]:.3f}"
174-
elif args.mode=="upperlim":
175-
label =f"{args.CL*100:.1f}% CL interval: {args.param} < {interval[1]:.3f}"
176-
elif args.mode=="lowerlim":
177-
label =f"{args.CL*100:.1f}% CL interval: {args.param} > {interval[0]:.3f}"
175+
if args.mode == "interval":
176+
label = f"{args.CL*100:.1f}% CL interval: {interval[0]:.3f} < {args.param} < {interval[1]:.3f}"
177+
elif args.mode == "upperlim":
178+
label = f"{args.CL*100:.1f}% CL interval: {args.param} < {interval[1]:.3f}"
179+
elif args.mode == "lowerlim":
180+
label = f"{args.CL*100:.1f}% CL interval: {args.param} > {interval[0]:.3f}"
178181

179182
print(label)
180-
ax[0].axvline(interval[0], color='red', linestyle='--', label=label)
181-
ax[0].axvline(interval[1], color='red', linestyle='--')
182-
# put legend in the upper left corner, outside the plot
183-
ax[0].legend(loc='upper left', bbox_to_anchor=(.01, 1.14))
183+
ax[0].axvline(interval[0], color="red", linestyle="--", label=label)
184+
ax[0].axvline(interval[1], color="red", linestyle="--")
185+
# put legend in the upper left corner, outside the plot
186+
ax[0].legend(loc="upper left", bbox_to_anchor=(0.01, 1.14))
184187

185-
for k,gr in enumerate(all_graphs):
186-
if graphs[k]: ax[1].plot(np.arange(len(gr)),gr, color='black', marker=None, linestyle='-',linewidth=0.2, alpha=0.4)
188+
for k, gr in enumerate(all_graphs):
189+
if graphs[k]:
190+
ax[1].plot(np.arange(len(gr)), gr, color="black", marker=None, linestyle="-", linewidth=0.2, alpha=0.4)
187191

188192
ax[1].set_ylabel(args.param)
189-
ax[1].axvline(args.burnInFraction*average_chain_length, color='blue', linestyle='--', label="Burn-in fraction")
193+
ax[1].axvline(args.burnInFraction * average_chain_length, color="blue", linestyle="--", label="Burn-in fraction")
190194
ax[1].set_xlabel("Chain index")
191195
ax[1].set_title(f"Trace plot of {kept_chain} chains / {j+1} chains")
192196

193197
# make a sliding average and 68% interval plot on top of the trace plot
194198
# this should be across the graphs and take ~5% of the average chain length as the window size
195-
window_size = int(average_chain_length*0.05)
196-
num_windows = int(average_chain_length/window_size)
199+
window_size = int(average_chain_length * 0.05)
200+
num_windows = int(average_chain_length / window_size)
197201
running_avg = np.empty(num_windows)
198202
running_avg_upper = np.empty(num_windows)
199-
running_avg_lower = np.empty(num_windows)
203+
running_avg_lower = np.empty(num_windows)
200204
window_centers = np.empty(num_windows)
201205

202206
for i in range(num_windows):
203207
window_vals = []
204208
window_weights = []
205209
for gr in all_graphs:
206-
if (i+1)*window_size > len(gr): continue
207-
window_vals.append(gr[i*window_size:(i+1)*window_size])
208-
window_weights.append(np.ones(window_size)) # equal weights for the running average
210+
if (i + 1) * window_size > len(gr):
211+
continue
212+
window_vals.append(gr[i * window_size : (i + 1) * window_size])
213+
window_weights.append(np.ones(window_size)) # equal weights for the running average
209214
window_vals = np.concatenate(window_vals)
210215
window_weights = np.concatenate(window_weights)
211216
running_avg[i] = np.average(window_vals, weights=window_weights)
212-
interval = findInterval(window_vals, window_weights,0.68,mode='interval')
217+
interval = findInterval(window_vals, window_weights, 0.68, mode="interval")
213218
running_avg_lower[i] = interval[0]
214219
running_avg_upper[i] = interval[1]
215-
window_center = (i*window_size)+window_size/2
220+
window_center = (i * window_size) + window_size / 2
216221
window_centers[i] = window_center
217222

218-
ax[1].plot(window_centers,running_avg, color='red', marker=None, linestyle='-',linewidth=2, label="Sliding average")
219-
ax[1].fill_between(window_centers, running_avg_lower, running_avg_upper, color='red', alpha=0.5, label="68% interval")
220-
ax[1].legend(loc='upper right')
223+
ax[1].plot(window_centers, running_avg, color="red", marker=None, linestyle="-", linewidth=2, label="Sliding average")
224+
ax[1].fill_between(window_centers, running_avg_lower, running_avg_upper, color="red", alpha=0.5, label="68% interval")
225+
ax[1].legend(loc="upper right")
221226

222227
if args.range:
223-
ax[1].set_ylim(args.range[0],args.range[1])
228+
ax[1].set_ylim(args.range[0], args.range[1])
224229

225-
if not args.output : args.output = args.param
230+
if not args.output:
231+
args.output = args.param
226232
plt.savefig(args.output + ".pdf")
227233
plt.savefig(args.output + ".png")
228-
print(f'Saved output as {args.output}.pdf/png')
234+
print(f"Saved output as {args.output}.pdf/png")

0 commit comments

Comments
 (0)