Skip to content

Commit 6070627

Browse files
committed
Set up working infrastructure for batched KF
1 parent dcad7f2 commit 6070627

File tree

3 files changed

+431
-3
lines changed

3 files changed

+431
-3
lines changed

conda-envs/environment-test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ dependencies:
1818
- pip:
1919
- jax
2020
- blackjax
21+
- -e .

notebooks/batch-examples.ipynb

Lines changed: 388 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,388 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "0a5841d3",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import numpy as np\n",
11+
"import pytensor\n",
12+
"import pytensor.tensor as pt\n",
13+
"from pymc_extras.statespace.filters import StandardFilter\n",
14+
"from tests.statespace.utilities.test_helpers import make_test_inputs\n",
15+
"from pytensor.graph.replace import vectorize_graph\n",
16+
"from importlib import reload\n",
17+
"import pymc_extras.statespace.filters.distributions as pmss_dist\n",
18+
"from pymc_extras.statespace.filters.distributions import SequenceMvNormal\n",
19+
"import pymc as pm"
20+
]
21+
},
22+
{
23+
"cell_type": "code",
24+
"execution_count": 2,
25+
"id": "14299e50",
26+
"metadata": {},
27+
"outputs": [],
28+
"source": [
29+
"seed = sum(map(ord, \"batched-kf\"))\n",
30+
"rng = np.random.default_rng(seed)"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": 3,
36+
"id": "71bc513e",
37+
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"def create_batch_inputs(batch_size, p=1, m=5, r=1, n=10, rng=rng):\n",
41+
" \"\"\"\n",
42+
" Create batched inputs for testing.\n",
43+
"\n",
44+
" Parameters\n",
45+
" ----------\n",
46+
" batch_size : int\n",
47+
"        Number of independent input sets to draw and stack along a new leading axis\n",
48+
" p : int\n",
49+
"        Forwarded to make_test_inputs as `p`; presumably the number of observed series (TODO confirm against the helper)\n",
50+
" m : int\n",
51+
"        Forwarded to make_test_inputs as `m`; presumably the number of hidden states (TODO confirm against the helper)\n",
52+
" r : int\n",
53+
"        Forwarded to make_test_inputs as `r`; presumably the number of shocks (TODO confirm against the helper)\n",
54+
" n : int\n",
55+
"        Forwarded to make_test_inputs as `n`; presumably the number of time steps (TODO confirm against the helper)\n",
56+
" rng : numpy.random.Generator\n",
57+
" Random number generator\n",
58+
"\n",
59+
" Returns\n",
60+
" -------\n",
61+
" list\n",
62+
"        One array per output of make_test_inputs, each stacked over the batch on axis 0\n",
63+
" \"\"\"\n",
64+
" # Create individual inputs for each batch\n",
65+
" np_batch_inputs = []\n",
66+
" for i in range(batch_size):\n",
67+
" inputs = make_test_inputs(p, m, r, n, rng)\n",
68+
" np_batch_inputs.append(inputs)\n",
69+
"\n",
70+
" return [np.stack(x, axis=0) for x in zip(*np_batch_inputs)]"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": 4,
76+
"id": "0c1824cf",
77+
"metadata": {},
78+
"outputs": [
79+
{
80+
"data": {
81+
"text/plain": [
82+
"(3, 10, 1)"
83+
]
84+
},
85+
"execution_count": 4,
86+
"metadata": {},
87+
"output_type": "execute_result"
88+
}
89+
],
90+
"source": [
91+
"# Create batch inputs with batch size 3\n",
92+
"np_batch_inputs = create_batch_inputs(3)\n",
93+
"np_batch_inputs[0].shape"
94+
]
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": 5,
99+
"id": "773d4cb4",
100+
"metadata": {},
101+
"outputs": [],
102+
"source": [
103+
"p, m, r, n = 1, 5, 1, 10\n",
104+
"inputs = [pt.as_tensor(x).type() for x in make_test_inputs(p, m, r, n, rng)]"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 6,
110+
"id": "511de29f",
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
"kf = StandardFilter()\n",
115+
"kf_outputs = kf.build_graph(*inputs)"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": 7,
121+
"id": "33006d8e",
122+
"metadata": {},
123+
"outputs": [],
124+
"source": [
125+
"batched_inputs = [pt.tensor(shape=(None, *x.type.shape)) for x in inputs]\n",
126+
"vec_subs = dict(zip(inputs, batched_inputs))\n",
127+
"bacthed_kf_outputs = vectorize_graph(kf_outputs, vec_subs)"
128+
]
129+
},
130+
{
131+
"cell_type": "code",
132+
"execution_count": 8,
133+
"id": "987a4647",
134+
"metadata": {},
135+
"outputs": [
136+
{
137+
"data": {
138+
"text/plain": [
139+
"[filtered_states,\n",
140+
" predicted_states,\n",
141+
" observed_states,\n",
142+
" filtered_covariances,\n",
143+
" predicted_covariances,\n",
144+
" observed_covariances,\n",
145+
" loglike_obs]"
146+
]
147+
},
148+
"execution_count": 8,
149+
"metadata": {},
150+
"output_type": "execute_result"
151+
}
152+
],
153+
"source": [
154+
"kf_outputs"
155+
]
156+
},
157+
{
158+
"cell_type": "code",
159+
"execution_count": 9,
160+
"id": "4b8be0f9",
161+
"metadata": {},
162+
"outputs": [],
163+
"source": [
164+
"mu = bacthed_kf_outputs[1]\n",
165+
"cov = bacthed_kf_outputs[4]\n",
166+
"logp = bacthed_kf_outputs[-1]"
167+
]
168+
},
169+
{
170+
"cell_type": "code",
171+
"execution_count": 10,
172+
"id": "1dc80f94",
173+
"metadata": {},
174+
"outputs": [
175+
{
176+
"data": {
177+
"text/plain": [
178+
"(None, 10, 5)"
179+
]
180+
},
181+
"execution_count": 10,
182+
"metadata": {},
183+
"output_type": "execute_result"
184+
}
185+
],
186+
"source": [
187+
"mu.type.shape"
188+
]
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": 20,
193+
"id": "1262c7d4",
194+
"metadata": {},
195+
"outputs": [],
196+
"source": [
197+
"pmss_dist = reload(pmss_dist)"
198+
]
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": 21,
203+
"id": "2dcd3958",
204+
"metadata": {},
205+
"outputs": [
206+
{
207+
"name": "stdout",
208+
"output_type": "stream",
209+
"text": [
210+
"mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)\n",
211+
"mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)\n",
212+
"mvn_seq.type.shape: (None, None, 5)\n",
213+
"mvn_seq.type.shape: (None, 10, 5)\n",
214+
"mvn_seq.type.shape: (None, 10, 5)\n",
215+
"mvn_seq.type.shape: (None, 10, 5)\n",
216+
"mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)\n",
217+
"mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)\n",
218+
"mvn_seq.type.shape: (None, None, 5)\n",
219+
"mvn_seq.type.shape: (None, 10, 5)\n",
220+
"mvn_seq.type.shape: (None, 10, 5)\n",
221+
"mvn_seq.type.shape: (None, 10, 5)\n"
222+
]
223+
}
224+
],
225+
"source": [
226+
"mv_outputs = pmss_dist.SequenceMvNormal.dist(mus=mu, covs=cov, logp=logp)"
227+
]
228+
},
229+
{
230+
"cell_type": "code",
231+
"execution_count": 22,
232+
"id": "6f41344f",
233+
"metadata": {},
234+
"outputs": [],
235+
"source": [
236+
"np_batch_inputs = create_batch_inputs(3)"
237+
]
238+
},
239+
{
240+
"cell_type": "code",
241+
"execution_count": 23,
242+
"id": "44905b8a",
243+
"metadata": {},
244+
"outputs": [],
245+
"source": [
246+
"np_batch_inputs[0] = rng.normal(size=(3, 10, 1))"
247+
]
248+
},
249+
{
250+
"cell_type": "code",
251+
"execution_count": 24,
252+
"id": "34fe01b8",
253+
"metadata": {},
254+
"outputs": [
255+
{
256+
"data": {
257+
"text/plain": [
258+
"(3, 10, 5)"
259+
]
260+
},
261+
"execution_count": 24,
262+
"metadata": {},
263+
"output_type": "execute_result"
264+
}
265+
],
266+
"source": [
267+
"f_test = pytensor.function(batched_inputs, mv_outputs)\n",
268+
"f_test(*np_batch_inputs).shape"
269+
]
270+
},
271+
{
272+
"cell_type": "code",
273+
"execution_count": 25,
274+
"id": "f37efe79",
275+
"metadata": {},
276+
"outputs": [
277+
{
278+
"name": "stdout",
279+
"output_type": "stream",
280+
"text": [
281+
"(None, 10, 1) (None, 10, 5) (None, 10, 5, 5)\n"
282+
]
283+
}
284+
],
285+
"source": [
286+
"f_mv = pytensor.function(batched_inputs, pm.logp(mv_outputs, batched_inputs[0]))"
287+
]
288+
},
289+
{
290+
"cell_type": "code",
291+
"execution_count": 26,
292+
"id": "7b45de74",
293+
"metadata": {},
294+
"outputs": [
295+
{
296+
"data": {
297+
"text/plain": [
298+
"(3, 10)"
299+
]
300+
},
301+
"execution_count": 26,
302+
"metadata": {},
303+
"output_type": "execute_result"
304+
}
305+
],
306+
"source": [
307+
"f_mv(*np_batch_inputs).shape"
308+
]
309+
},
310+
{
311+
"cell_type": "code",
312+
"execution_count": null,
313+
"id": "f14596aa",
314+
"metadata": {},
315+
"outputs": [],
316+
"source": []
317+
},
318+
{
319+
"cell_type": "code",
320+
"execution_count": 27,
321+
"id": "69519822",
322+
"metadata": {},
323+
"outputs": [],
324+
"source": [
325+
"f = pytensor.function(batched_inputs, bacthed_kf_outputs)"
326+
]
327+
},
328+
{
329+
"cell_type": "code",
330+
"execution_count": 28,
331+
"id": "3f745449",
332+
"metadata": {},
333+
"outputs": [
334+
{
335+
"name": "stdout",
336+
"output_type": "stream",
337+
"text": [
338+
"633 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
339+
"1.52 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
340+
"4.76 ms ± 259 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
341+
]
342+
}
343+
],
344+
"source": [
345+
"for s in [1, 3, 10]:\n",
346+
" np_batch_inputs = create_batch_inputs(s)\n",
347+
" %timeit outputs = f(*np_batch_inputs)"
348+
]
349+
},
350+
{
351+
"cell_type": "code",
352+
"execution_count": null,
353+
"id": "d5fcadef",
354+
"metadata": {},
355+
"outputs": [],
356+
"source": []
357+
},
358+
{
359+
"cell_type": "code",
360+
"execution_count": null,
361+
"id": "c479ff22",
362+
"metadata": {},
363+
"outputs": [],
364+
"source": []
365+
}
366+
],
367+
"metadata": {
368+
"kernelspec": {
369+
"display_name": "pymc-extras-test",
370+
"language": "python",
371+
"name": "python3"
372+
},
373+
"language_info": {
374+
"codemirror_mode": {
375+
"name": "ipython",
376+
"version": 3
377+
},
378+
"file_extension": ".py",
379+
"mimetype": "text/x-python",
380+
"name": "python",
381+
"nbconvert_exporter": "python",
382+
"pygments_lexer": "ipython3",
383+
"version": "3.12.9"
384+
}
385+
},
386+
"nbformat": 4,
387+
"nbformat_minor": 5
388+
}

0 commit comments

Comments
 (0)