diff --git a/notebooks/INLA_JESSE.ipynb b/notebooks/INLA_JESSE.ipynb
new file mode 100644
index 00000000..7d5eddf1
--- /dev/null
+++ b/notebooks/INLA_JESSE.ipynb
@@ -0,0 +1,522 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "ffd6780e-1bfb-42f0-ba6a-055e9ffd1490",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "5a2819fd-6e01-47c0-88b2-f2b5e4215b9b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pymc as pm\n",
+ "import pytensor.tensor as pt\n",
+ "\n",
+ "import pytensor\n",
+ "\n",
+ "from pymc.model.fgraph import fgraph_from_model, model_from_fgraph"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "9f8841b1-9da5-43b2-bf64-6bb93c8f4e93",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def marginalize(\n",
+ " model,\n",
+ " rvs_to_marginalize,\n",
+ " Q,\n",
+ " temp_kwargs,\n",
+ " minimizer_kwargs={\"method\": \"BFGS\", \"optimizer_kwargs\": {\"tol\": 1e-8}},\n",
+ "):\n",
+ " from pymc.model.fgraph import (\n",
+ " ModelFreeRV,\n",
+ " ModelValuedVar,\n",
+ " )\n",
+ "\n",
+ " from pymc_extras.model.marginal.graph_analysis import (\n",
+ " find_conditional_dependent_rvs,\n",
+ " find_conditional_input_rvs,\n",
+ " is_conditional_dependent,\n",
+ " subgraph_batch_dim_connection,\n",
+ " )\n",
+ "\n",
+ " from pymc_extras.model.marginal.marginal_model import (\n",
+ " _unique,\n",
+ " collect_shared_vars,\n",
+ " remove_model_vars,\n",
+ " )\n",
+ "\n",
+ " from pymc_extras.model.marginal.distributions import (\n",
+ " MarginalLaplaceRV,\n",
+ " )\n",
+ "\n",
+ " from pymc.pytensorf import collect_default_updates\n",
+ "\n",
+ " from pytensor.graph import (\n",
+ " FunctionGraph,\n",
+ " Variable,\n",
+ " clone_replace,\n",
+ " )\n",
+ "\n",
+ " fg, memo = fgraph_from_model(model)\n",
+ " rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]\n",
+ " toposort = fg.toposort()\n",
+ "\n",
+ " for rv_to_marginalize in sorted(\n",
+ " rvs_to_marginalize,\n",
+ " key=lambda rv: toposort.index(rv.owner),\n",
+ " reverse=True,\n",
+ " ):\n",
+ " all_rvs = [node.out for node in fg.toposort() if isinstance(node.op, ModelValuedVar)]\n",
+ "\n",
+ " dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)\n",
+ " if not dependent_rvs:\n",
+ " # TODO: This should at most be a warning, not an error\n",
+ " raise ValueError(f\"No RVs depend on marginalized RV {rv_to_marginalize}\")\n",
+ "\n",
+ " # Issue warning for IntervalTransform on dependent RVs\n",
+ " for dependent_rv in dependent_rvs:\n",
+ " transform = dependent_rv.owner.op.transform\n",
+ "\n",
+ " # if isinstance(transform, IntervalTransform) or (\n",
+ " # isinstance(transform, Chain)\n",
+ " # and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list)\n",
+ " # ):\n",
+ " # warnings.warn(\n",
+ " # f\"The transform {transform} for the variable {dependent_rv}, which depends on the \"\n",
+ " # f\"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.\",\n",
+ " # UserWarning,\n",
+ " # )\n",
+ "\n",
+ " # Check that no deterministics or potentials depend on the rv to marginalize\n",
+ " for det in model.deterministics:\n",
+ " if is_conditional_dependent(memo[det], rv_to_marginalize, all_rvs):\n",
+ " raise NotImplementedError(\n",
+ " f\"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}\"\n",
+ " )\n",
+ " for pot in model.potentials:\n",
+ " if is_conditional_dependent(memo[pot], rv_to_marginalize, all_rvs):\n",
+ " raise NotImplementedError(\n",
+ " f\"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}\"\n",
+ " )\n",
+ "\n",
+ " marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)\n",
+ " other_direct_rv_ancestors = [\n",
+ " rv\n",
+ " for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)\n",
+ " if rv is not rv_to_marginalize\n",
+ " ]\n",
+ " input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))\n",
+ "\n",
+ " output_rvs = [rv_to_marginalize, *dependent_rvs]\n",
+ " rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)\n",
+ " outputs = output_rvs + list(rng_updates.values())\n",
+ " inputs = input_rvs + list(rng_updates.keys())\n",
+ " # Add any other shared variable inputs\n",
+ " inputs += collect_shared_vars(output_rvs, blockers=inputs)\n",
+ "\n",
+ " inner_inputs = [inp.clone() for inp in inputs]\n",
+ " inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs)))\n",
+ " inner_outputs = remove_model_vars(inner_outputs)\n",
+ "\n",
+ " marginalize_constructor = MarginalLaplaceRV\n",
+ "\n",
+ " _, _, *dims = rv_to_marginalize.owner.inputs\n",
+ " marginalization_op = marginalize_constructor(\n",
+ " inputs=inner_inputs,\n",
+ " outputs=inner_outputs,\n",
+ " dims_connections=[\n",
+ " (None,),\n",
+ " ], # dependent_rvs_dim_connections, # TODO NOT SURE WHAT THIS IS\n",
+ " dims=dims,\n",
+ " Q=Q,\n",
+ " temp_kwargs=temp_kwargs,\n",
+ " minimizer_kwargs=minimizer_kwargs,\n",
+ " )\n",
+ "\n",
+ " new_outputs = marginalization_op(*inputs)\n",
+ " for old_output, new_output in zip(outputs, new_outputs):\n",
+ " new_output.name = old_output.name\n",
+ "\n",
+ " model_replacements = []\n",
+ " for old_output, new_output in zip(outputs, new_outputs):\n",
+ " if old_output is rv_to_marginalize or not isinstance(\n",
+ " old_output.owner.op, ModelValuedVar\n",
+ " ):\n",
+ " # Replace the marginalized ModelFreeRV (or non model-variables) themselves\n",
+ " var_to_replace = old_output\n",
+ " else:\n",
+ " # Replace the underlying RV, keeping the same value, transform and dims\n",
+ " var_to_replace = old_output.owner.inputs[0]\n",
+ " model_replacements.append((var_to_replace, new_output))\n",
+ "\n",
+ " fg.replace_all(model_replacements)\n",
+ "\n",
+ " return model_from_fgraph(fg, mutate_fgraph=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "47d6057b-b459-43ee-afdb-32e63cee5e62",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "rng = np.random.default_rng(12345)\n",
+ "n = 10000\n",
+ "d = 3\n",
+ "\n",
+ "mu_mu = np.zeros((d,))\n",
+ "mu_true = np.ones((d,))\n",
+ "\n",
+ "cov = np.diag(np.ones(d))\n",
+ "Q_val = np.diag(np.ones(d))\n",
+ "cov_true = np.diag(np.ones(d))\n",
+ "\n",
+ "with pm.Model() as model:\n",
+ " x_mu = pm.MvNormal(\"mu_param\", mu=mu_mu, cov=cov)\n",
+ "\n",
+ " x = pm.MvNormal(\"x\", mu=x_mu, tau=Q_val)\n",
+ "\n",
+ " y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n)\n",
+ "\n",
+ " y = pm.MvNormal(\n",
+ " \"y\",\n",
+ " mu=x,\n",
+ " cov=cov,\n",
+ " observed=y_obs,\n",
+ " )\n",
+ "\n",
+ "pm.model_to_graphviz(model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "93dbf1e3-0242-4da8-aa2f-62c5541f14cc",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model_marg = marginalize(\n",
+ " model,\n",
+ " [x],\n",
+ " Q=Q_val,\n",
+ " temp_kwargs=None,\n",
+ " minimizer_kwargs={\"method\": \"L-BFGS-B\", \"optimizer_kwargs\": {\"tol\": 1e-8}},\n",
+ ")\n",
+ "pm.model_to_graphviz(model_marg, var_names=[\"y\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "41d721e1-fdd4-4d58-95b7-bc978c468ef0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rng = np.random.default_rng(12345)\n",
+ "mu = rng.random(d)\n",
+ "\n",
+ "f_logp = model_marg.compile_logp(profile=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "41b19afc-dc41-4e6a-908f-41f9fd06f188",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9.58 ms ± 428 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%timeit\n",
+ "f_logp({\"mu_param\": mu})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "8440ca77-e7c8-40ba-a106-40fef8b87d34",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Function profiling\n",
+ "==================\n",
+ " Message: /home/michaln/git/pymc/pymc/pytensorf.py:942\n",
+ " Time in 811 calls to Function.__call__: 7.819569e+00s\n",
+ " Time in Function.vm.__call__: 7.7988263611114235s (99.735%)\n",
+ " Time in thunks: 7.780319690704346s (99.498%)\n",
+ " Total compilation time: 6.888381e-01s\n",
+ " Number of Apply nodes: 29\n",
+ " PyTensor rewrite time: 3.427114e-01s\n",
+ " PyTensor validate time: 5.761475e-03s\n",
+ " PyTensor Linker time (includes C, CUDA code generation/compiling): 0.3434671189970686s\n",
+ " C-cache preloading 5.671625e-03s\n",
+ " Import time 0.000000e+00s\n",
+ " Node make_thunk time 3.368705e-01s\n",
+ " Node MinimizeOp(method=L-BFGS-B, jac=True, hess=False, hessp=False)([0.2273360 ... .79736546], True, [0.], [[2]], [[[1. 0. 0 ... . 0. 1.]]], [0.5], True, 0.0, [[2]], ExpandDims{axis=0}.0, [[[1. 0. 0 ... . 0. 1.]]]) time 2.206649e-01s\n",
+ " Node Scan{scan_fn, while_loop=False, inplace=none}(3, [0 1 2], 3, Neg.0, Neg.0) time 1.019020e-01s\n",
+ " Node Squeeze{axis=0}(CAReduce{Composite{(i0 + sqr(i1))}, axis=1}.0) time 1.282859e-03s\n",
+ " Node ExpandDims{axis=1}(mu_param) time 1.102447e-03s\n",
+ " Node ExpandDims{axis=0}(mu_param) time 1.032342e-03s\n",
+ "\n",
+ "Time in all call to pytensor.grad() 7.087498e-01s\n",
+ "Time since pytensor import 120.663s\n",
+ "Class\n",
+ "---\n",
+ "<% time>