Skip to content

Commit 5b8157c

Browse files
Nathan SimpsonNathan Simpson
andauthored
Change API to make more flexible in practice (#33)
* total change * format * format * format * oh well then * more * pre-commit * format * add more examples Co-authored-by: Nathan Simpson <phinate@protonmail.com>
1 parent 17b85c0 commit 5b8157c

26 files changed

+3605
-4524
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: 22.1.0
3+
rev: 22.3.0
44
hooks:
55
- id: black-jupyter
66

demo.ipynb

Lines changed: 1272 additions & 383 deletions
Large diffs are not rendered by default.

examples/ap000.gif

213 KB
Loading

examples/binning.ipynb

Lines changed: 1081 additions & 0 deletions
Large diffs are not rendered by default.

examples/cuts.ipynb

Lines changed: 457 additions & 0 deletions
Large diffs are not rendered by default.

examples/diffable_histograms.ipynb

Lines changed: 380 additions & 0 deletions
Large diffs are not rendered by default.

examples/float.png

30.2 KB
Loading

examples/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
celluloid
2+
git+http://github.com/scikit-hep/pyhf.git@make_difffable_model_ctor
3+
plothelp
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import jax\n",
10+
"import jax.numpy as jnp\n",
11+
"import matplotlib.pyplot as plt\n",
12+
"import optax\n",
13+
"from jaxopt import OptaxSolver\n",
14+
"import relaxed\n",
15+
"from celluloid import Camera\n",
16+
"from functools import partial\n",
17+
"import matplotlib.lines as mlines\n",
18+
"\n",
19+
"# matplotlib settings\n",
20+
"plt.rc(\"figure\", figsize=(6, 3), dpi=220, facecolor=\"w\")\n",
21+
"plt.rc(\"legend\", fontsize=6)"
22+
]
23+
},
24+
{
25+
"cell_type": "markdown",
26+
"metadata": {},
27+
"source": [
28+
"# Optimising a simple one-bin analysis with `relaxed`\n",
29+
"\n",
30+
"Let's define an analysis with a predicted number of signal and background events, with some uncertainty on the background estimate. We'll abstract the analysis configuration into a single parameter $\\phi$ like so:\n",
31+
"\n",
32+
"$$s = 15 + \\phi $$\n",
33+
"$$b = 45 - 2 \\phi $$\n",
34+
"$$\sigma_b = 0.5 + 0.1 \phi^2 $$\n",
35+
"\n",
36+
"Note that $s \\propto \\phi$ and $\\propto -2\\phi$, so increasing $\\phi$ corresponds to increasing the signal/backround ratio. However, our uncertainty scales like $\\phi^2$, so we're also going to compromise in our certainty of the background count as we do that. This kind of tradeoff between $s/b$ ratio and uncertainty is important for the discovery of a new signal, so we can't get away with optimising $s/b$ alone.\n",
37+
"\n",
38+
"To illustrate this, we'll plot the discovery significance for this model with and without uncertainty."
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": null,
44+
"metadata": {},
45+
"outputs": [],
46+
"source": [
47+
"# model definition\n",
48+
"def yields(phi, uncertainty=True):\n",
49+
" s = 15 + phi\n",
50+
" b = 45 - 2 * phi\n",
51+
" db = (\n",
52+
" 0.5 + 0.1 * phi**2 if uncertainty else jnp.zeros_like(phi) + 0.001\n",
53+
" ) # small enough to be negligible\n",
54+
" return jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db])\n",
55+
"\n",
56+
"\n",
57+
"# our analysis pipeline, from phi to p-value\n",
58+
"def pipeline(phi, return_yields=False, uncertainty=True):\n",
59+
" y = yields(phi, uncertainty=uncertainty)\n",
60+
" # use a dummy version of pyhf for simplicity + compatibility with jax\n",
61+
" model = relaxed.dummy_pyhf.uncorrelated_background(*y)\n",
62+
" nominal_pars = jnp.array([1.0, 1.0])\n",
63+
" data = model.expected_data(nominal_pars) # we expect the nominal model\n",
64+
" # do the hypothesis test (and fit model pars with gradient descent)\n",
65+
" pvalue = relaxed.infer.hypotest(\n",
66+
" 0.0, # value of mu for the alternative hypothesis\n",
67+
" data,\n",
68+
" model,\n",
69+
" test_stat=\"q0\", # discovery significance test\n",
70+
" lr=1e-3,\n",
71+
" expected_pars=nominal_pars, # optionally providing MLE pars in advance\n",
72+
" )\n",
73+
" if return_yields:\n",
74+
" return pvalue, y\n",
75+
" else:\n",
76+
" return pvalue\n",
77+
"\n",
78+
"\n",
79+
"# calculate p-values for a range of phi values\n",
80+
"phis = jnp.linspace(0, 10, 100)\n",
81+
"\n",
82+
"# with uncertainty\n",
83+
"pipe = partial(pipeline, return_yields=True, uncertainty=True)\n",
84+
"pvals, ys = jax.vmap(pipe)(phis) # map over phi grid\n",
85+
"# without uncertainty\n",
86+
"pipe_no_uncertainty = partial(pipeline, uncertainty=False)\n",
87+
"pvals_no_uncertainty = jax.vmap(pipe_no_uncertainty)(phis)"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"fig, axs = plt.subplots(2, 1, sharex=True)\n",
97+
"axs[0].plot(phis, pvals, label=\"with uncertainty\", color=\"C2\")\n",
98+
"axs[0].plot(phis, pvals_no_uncertainty, label=\"no uncertainty\", color=\"C4\")\n",
99+
"axs[0].set_ylabel(\"$p$-value\")\n",
100+
"# plot vertical dotted line at minimum of p-values + s/b\n",
101+
"best_phi = phis[jnp.argmin(pvals)]\n",
102+
"axs[0].axvline(x=best_phi, linestyle=\"dotted\", color=\"C2\", label=\"optimal p-value\")\n",
103+
"axs[0].axvline(\n",
104+
" x=phis[jnp.argmin(pvals_no_uncertainty)],\n",
105+
" linestyle=\"dotted\",\n",
106+
" color=\"C4\",\n",
107+
" label=r\"optimal $s/b$\",\n",
108+
")\n",
109+
"axs[0].legend(loc=\"upper left\", ncol=2)\n",
110+
"s, b, db = ys\n",
111+
"s, b, db = s.ravel(), b.ravel(), db.ravel() # everything is [[x]] for pyhf\n",
112+
"axs[1].fill_between(phis, s + b, b, color=\"C9\", label=\"signal\")\n",
113+
"axs[1].fill_between(phis, b, color=\"C1\", label=\"background\")\n",
114+
"axs[1].fill_between(phis, b - db, b + db, facecolor=\"k\", alpha=0.2, label=r\"$\\sigma_b$\")\n",
115+
"axs[1].set_xlabel(\"$\\phi$\")\n",
116+
"axs[1].set_ylabel(\"yield\")\n",
117+
"axs[1].legend(loc=\"lower left\")\n",
118+
"plt.suptitle(\"Discovery p-values, with and without uncertainty\")\n",
119+
"plt.tight_layout()"
120+
]
121+
},
122+
{
123+
"cell_type": "markdown",
124+
"metadata": {},
125+
"source": [
126+
"Using gradient descent, we can optimise this analysis in an uncertainty-aware way by directly optimising $\\phi$ for the lowest discovery p-value. Here's how you do that:"
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": null,
132+
"metadata": {},
133+
"outputs": [],
134+
"source": [
135+
"# The fast way!\n",
136+
"# use the OptaxSolver wrapper from jaxopt to perform the minimisation\n",
137+
"# set a couple of tolerance kwargs to make sure we don't get stuck\n",
138+
"solver = OptaxSolver(pipeline, opt=optax.adam(1e-3), tol=1e-8, maxiter=10000)\n",
139+
"pars = 9.0 # random init\n",
140+
"result = solver.run(pars).params\n",
141+
"print(\n",
142+
" f\"our solution: phi={result:.5f}\\ntrue optimum: phi={phis[jnp.argmin(pvals)]:.5f}\\nbest s/b: phi=10\"\n",
143+
")"
144+
]
145+
},
146+
{
147+
"cell_type": "code",
148+
"execution_count": null,
149+
"metadata": {},
150+
"outputs": [],
151+
"source": [
152+
"# The longer way (but with plots)!\n",
153+
"pipe = partial(pipeline, return_yields=True, uncertainty=True)\n",
154+
"solver = OptaxSolver(pipe, opt=optax.adam(1e-1), has_aux=True)\n",
155+
"pars = 9.0\n",
156+
"state = solver.init_state(pars) # we're doing init, update steps instead of .run()\n",
157+
"\n",
158+
"plt.rc(\"figure\", figsize=(6, 3), dpi=220, facecolor=\"w\")\n",
159+
"plt.rc(\"legend\", fontsize=8)\n",
160+
"fig, axs = plt.subplots(1, 2)\n",
161+
"cam = Camera(fig)\n",
162+
"steps = 5 # increase me for better results! (100ish works well)\n",
163+
"for i in range(steps):\n",
164+
" pars, state = solver.update(pars, state)\n",
165+
" s, b, db = state.aux\n",
166+
" val = state.value\n",
167+
" ax = axs[0]\n",
168+
" cv = ax.plot(phis, pvals, c=\"C0\")\n",
169+
" cvs = ax.plot(phis, pvals_no_uncertainty, c=\"green\")\n",
170+
" current = ax.scatter(pars, val, c=\"C0\")\n",
171+
" ax.set_xlabel(r\"analysis config $\\phi$\")\n",
172+
" ax.set_ylabel(\"p-value\")\n",
173+
" ax.legend(\n",
174+
" [\n",
175+
" mlines.Line2D([], [], color=\"C0\"),\n",
176+
" mlines.Line2D([], [], color=\"green\"),\n",
177+
" current,\n",
178+
" ],\n",
179+
" [\"p-value (with uncert)\", \"p-value (without uncert)\", \"current value\"],\n",
180+
" frameon=False,\n",
181+
" )\n",
182+
" ax.text(0.3, 0.61, f\"step {i}\", transform=ax.transAxes)\n",
183+
" ax = axs[1]\n",
184+
" ax.set_ylim((0, 80))\n",
185+
" b1 = ax.bar(0.5, b, facecolor=\"C1\", label=\"b\")\n",
186+
" b2 = ax.bar(0.5, s, bottom=b, facecolor=\"C9\", label=\"s\")\n",
187+
" b3 = ax.bar(\n",
188+
" 0.5, db, bottom=b - db / 2, facecolor=\"k\", alpha=0.5, label=r\"$\\sigma_b$\"\n",
189+
" )\n",
190+
" ax.set_ylabel(\"yield\")\n",
191+
" ax.set_xticks([])\n",
192+
" ax.legend([b1, b2, b3], [\"b\", \"s\", r\"$\\sigma_b$\"], frameon=False)\n",
193+
" plt.tight_layout()\n",
194+
" cam.snap()\n",
195+
"\n",
196+
"ani = cam.animate()\n",
197+
"# uncomment this to save and view the animation!\n",
198+
"# ani.save(\"ap00.gif\", fps=9)"
199+
]
200+
},
201+
{
202+
"cell_type": "code",
203+
"execution_count": null,
204+
"metadata": {},
205+
"outputs": [],
206+
"source": []
207+
}
208+
],
209+
"metadata": {
210+
"interpreter": {
211+
"hash": "22d6333b89854cd01c2018f3ca2f5a59a2cde2765fbca789ff36cfad48ca629b"
212+
},
213+
"kernelspec": {
214+
"display_name": "Python 3.9.12 ('venv': venv)",
215+
"language": "python",
216+
"name": "python3"
217+
},
218+
"language_info": {
219+
"codemirror_mode": {
220+
"name": "ipython",
221+
"version": 3
222+
},
223+
"file_extension": ".py",
224+
"mimetype": "text/x-python",
225+
"name": "python",
226+
"nbconvert_exporter": "python",
227+
"pygments_lexer": "ipython3",
228+
"version": "3.9.12"
229+
},
230+
"orig_nbformat": 4
231+
},
232+
"nbformat": 4,
233+
"nbformat_minor": 2
234+
}

examples/withbinfloat.png

27.4 KB
Loading

0 commit comments

Comments
 (0)