diff --git a/notebooks/INLA_JESSE.ipynb b/notebooks/INLA_JESSE.ipynb
new file mode 100644
index 00000000..7d5eddf1
--- /dev/null
+++ b/notebooks/INLA_JESSE.ipynb
@@ -0,0 +1,522 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "ffd6780e-1bfb-42f0-ba6a-055e9ffd1490",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "5a2819fd-6e01-47c0-88b2-f2b5e4215b9b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pymc as pm\n",
+    "import pytensor.tensor as pt\n",
+    "\n",
+    "import pytensor\n",
+    "\n",
+    "from pymc.model.fgraph import fgraph_from_model, model_from_fgraph"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "9f8841b1-9da5-43b2-bf64-6bb93c8f4e93",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def marginalize(\n",
+    "    model,\n",
+    "    rvs_to_marginalize,\n",
+    "    Q,\n",
+    "    temp_kwargs,\n",
+    "    minimizer_kwargs={\"method\": \"BFGS\", \"optimizer_kwargs\": {\"tol\": 1e-8}},\n",
+    "):\n",
+    "    from pymc.model.fgraph import (\n",
+    "        ModelFreeRV,\n",
+    "        ModelValuedVar,\n",
+    "    )\n",
+    "\n",
+    "    from pymc_extras.model.marginal.graph_analysis import (\n",
+    "        find_conditional_dependent_rvs,\n",
+    "        find_conditional_input_rvs,\n",
+    "        is_conditional_dependent,\n",
+    "        subgraph_batch_dim_connection,\n",
+    "    )\n",
+    "\n",
+    "    from pymc_extras.model.marginal.marginal_model import (\n",
+    "        _unique,\n",
+    "        collect_shared_vars,\n",
+    "        remove_model_vars,\n",
+    "    )\n",
+    "\n",
+    "    from pymc_extras.model.marginal.distributions import (\n",
+    "        MarginalLaplaceRV,\n",
+    "    )\n",
+    "\n",
+    "    from pymc.pytensorf import collect_default_updates\n",
+    "\n",
+    "    from pytensor.graph import (\n",
+    "        FunctionGraph,\n",
+    "        Variable,\n",
+    "        clone_replace,\n",
+    "    )\n",
+    "\n",
+    "    fg, memo = fgraph_from_model(model)\n",
+    "    rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]\n",
+    "    toposort = fg.toposort()\n",
+    "\n",
+    "    for rv_to_marginalize in sorted(\n",
+    "        rvs_to_marginalize,\n",
+    "        key=lambda rv: toposort.index(rv.owner),\n",
+    "        reverse=True,\n",
+    "    ):\n",
+    "        all_rvs = [node.out for node in fg.toposort() if isinstance(node.op, ModelValuedVar)]\n",
+    "\n",
+    "        dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)\n",
+    "        if not dependent_rvs:\n",
+    "            # TODO: This should at most be a warning, not an error\n",
+    "            raise ValueError(f\"No RVs depend on marginalized RV {rv_to_marginalize}\")\n",
+    "\n",
+    "        # Warning for IntervalTransform on dependent RVs is disabled in this notebook;\n",
+    "        # re-enabling it would also require the warnings, IntervalTransform and Chain imports.\n",
+    "        # for dependent_rv in dependent_rvs:\n",
+    "        #     transform = dependent_rv.owner.op.transform\n",
+    "        #     if isinstance(transform, IntervalTransform) or (\n",
+    "        #         isinstance(transform, Chain)\n",
+    "        #         and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list)\n",
+    "        #     ):\n",
+    "        #         warnings.warn(\n",
+    "        #             f\"The transform {transform} for the variable {dependent_rv}, which depends on the \"\n",
+    "        #             f\"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.\",\n",
+    "        #             UserWarning,\n",
+    "        #         )\n",
+    "\n",
+    "        # Check that no deterministics or potentials depend on the rv to marginalize\n",
+    "        for det in model.deterministics:\n",
+    "            if is_conditional_dependent(memo[det], rv_to_marginalize, all_rvs):\n",
+    "                raise NotImplementedError(\n",
+    "                    f\"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}\"\n",
+    "                )\n",
+    "        for pot in model.potentials:\n",
+    "            if is_conditional_dependent(memo[pot], rv_to_marginalize, all_rvs):\n",
+    "                raise NotImplementedError(\n",
+    "                    f\"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}\"\n",
+    "                )\n",
+    "\n",
+    "        marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)\n",
+    "        other_direct_rv_ancestors = [\n",
+    "            rv\n",
+    "            for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)\n",
+    "            if rv is not rv_to_marginalize\n",
+    "        ]\n",
+    "        input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))\n",
+    "\n",
+    "        output_rvs = [rv_to_marginalize, *dependent_rvs]\n",
+    "        rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)\n",
+    "        outputs = output_rvs + list(rng_updates.values())\n",
+    "        inputs = input_rvs + list(rng_updates.keys())\n",
+    "        # Add any other shared variable inputs\n",
+    "        inputs += collect_shared_vars(output_rvs, blockers=inputs)\n",
+    "\n",
+    "        inner_inputs = [inp.clone() for inp in inputs]\n",
+    "        inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs)))\n",
+    "        inner_outputs = remove_model_vars(inner_outputs)\n",
+    "\n",
+    "        marginalize_constructor = MarginalLaplaceRV\n",
+    "\n",
+    "        _, _, *dims = rv_to_marginalize.owner.inputs\n",
+    "        marginalization_op = marginalize_constructor(\n",
+    "            inputs=inner_inputs,\n",
+    "            outputs=inner_outputs,\n",
+    "            dims_connections=[\n",
+    "                (None,),\n",
+    "            ],  # dependent_rvs_dim_connections, # TODO NOT SURE WHAT THIS IS\n",
+    "            dims=dims,\n",
+    "            Q=Q,\n",
+    "            temp_kwargs=temp_kwargs,\n",
+    "            minimizer_kwargs=minimizer_kwargs,\n",
+    "        )\n",
+    "\n",
+    "        new_outputs = marginalization_op(*inputs)\n",
+    "        for old_output, new_output in zip(outputs, new_outputs):\n",
+    "            new_output.name = old_output.name\n",
+    "\n",
+    "        model_replacements = []\n",
+    "        for old_output, new_output in zip(outputs, new_outputs):\n",
+    "            if old_output is rv_to_marginalize or not isinstance(\n",
+    "                old_output.owner.op, ModelValuedVar\n",
+    "            ):\n",
+    "                # Replace the marginalized ModelFreeRV (or non model-variables) themselves\n",
+    "                var_to_replace = old_output\n",
+    "            else:\n",
+    "                # Replace the underlying RV, keeping the same value, transform and dims\n",
+    "                var_to_replace = old_output.owner.inputs[0]\n",
+    "            model_replacements.append((var_to_replace, new_output))\n",
+    "\n",
+    "        fg.replace_all(model_replacements)\n",
+    "\n",
+    "    return model_from_fgraph(fg, mutate_fgraph=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "47d6057b-b459-43ee-afdb-32e63cee5e62",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "[SVG output omitted: graphviz model graph with plates '3' and '10000 x 3'; mu_param ~ Multivariate_normal -> x ~ Multivariate_normal -> y ~ Multivariate_normal]"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "rng = np.random.default_rng(12345)\n",
+    "n = 10000\n",
+    "d = 3\n",
+    "\n",
+    "mu_mu = np.zeros((d,))\n",
+    "mu_true = np.ones((d,))\n",
+    "\n",
+    "cov = np.diag(np.ones(d))\n",
+    "Q_val = np.diag(np.ones(d))\n",
+    "cov_true = np.diag(np.ones(d))\n",
+    "\n",
+    "with pm.Model() as model:\n",
+    "    x_mu = pm.MvNormal(\"mu_param\", mu=mu_mu, cov=cov)\n",
+    "\n",
+    "    x = pm.MvNormal(\"x\", mu=x_mu, tau=Q_val)\n",
+    "\n",
+    "    y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n)\n",
+    "\n",
+    "    y = pm.MvNormal(\n",
+    "        \"y\",\n",
+    "        mu=x,\n",
+    "        cov=cov,\n",
+    "        observed=y_obs,\n",
+    "    )\n",
+    "\n",
+    "pm.model_to_graphviz(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "93dbf1e3-0242-4da8-aa2f-62c5541f14cc",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "[SVG output omitted: graphviz model graph with plates '10000 x 3' and '3'; mu_param ~ Multivariate_normal -> y ~ MarginalLaplace]"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model_marg = marginalize(\n",
+    "    model,\n",
+    "    [x],\n",
+    "    Q=Q_val,\n",
+    "    temp_kwargs=None,\n",
+    "    minimizer_kwargs={\"method\": \"L-BFGS-B\", \"optimizer_kwargs\": {\"tol\": 1e-8}},\n",
+    ")\n",
+    "pm.model_to_graphviz(model_marg, var_names=[\"y\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "41d721e1-fdd4-4d58-95b7-bc978c468ef0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rng = np.random.default_rng(12345)\n",
+    "mu = rng.random(d)\n",
+    "\n",
+    "f_logp = model_marg.compile_logp(profile=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "41b19afc-dc41-4e6a-908f-41f9fd06f188",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "9.58 ms ± 428 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "f_logp({\"mu_param\": mu})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "8440ca77-e7c8-40ba-a106-40fef8b87d34",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Function profiling\n",
+      "==================\n",
+      " Message: /home/michaln/git/pymc/pymc/pytensorf.py:942\n",
+      " Time in 811 calls to Function.__call__: 7.819569e+00s\n",
+      " Time in Function.vm.__call__: 7.7988263611114235s (99.735%)\n",
+      " Time in thunks: 7.780319690704346s (99.498%)\n",
+      " Total compilation time: 6.888381e-01s\n",
+      " Number of Apply nodes: 29\n",
+      " PyTensor rewrite time: 3.427114e-01s\n",
+      " PyTensor validate time: 5.761475e-03s\n",
+      " PyTensor Linker time (includes C, CUDA code generation/compiling): 0.3434671189970686s\n",
+      " C-cache preloading 5.671625e-03s\n",
+      " Import time 0.000000e+00s\n",
+      " Node make_thunk time 3.368705e-01s\n",
+      " Node MinimizeOp(method=L-BFGS-B, jac=True, hess=False, hessp=False)([0.2273360 ... .79736546], True, [0.], [[2]], [[[1. 0. 0 ... . 0. 1.]]], [0.5], True, 0.0, [[2]], ExpandDims{axis=0}.0, [[[1. 0. 0 ... . 0. 1.]]]) time 2.206649e-01s\n",
+      " Node Scan{scan_fn, while_loop=False, inplace=none}(3, [0 1 2], 3, Neg.0, Neg.0) time 1.019020e-01s\n",
+      " Node Squeeze{axis=0}(CAReduce{Composite{(i0 + sqr(i1))}, axis=1}.0) time 1.282859e-03s\n",
+      " Node ExpandDims{axis=1}(mu_param) time 1.102447e-03s\n",
+      " Node ExpandDims{axis=0}(mu_param) time 1.032342e-03s\n",
+      "\n",
+      "Time in all call to pytensor.grad() 7.087498e-01s\n",
+      "Time since pytensor import 120.663s\n",
+      "Class\n",
+      "---\n",
+      "<% time>