diff --git a/notebooks/Kalman_Filter_Gradient.ipynb b/notebooks/Kalman_Filter_Gradient.ipynb
new file mode 100644
index 00000000..ff31dbbb
--- /dev/null
+++ b/notebooks/Kalman_Filter_Gradient.ipynb
@@ -0,0 +1,1058 @@
+{
+ "cells": [
+ { + "cell_type": "markdown", + "id": "69ae14a1", + "metadata": {}, + "source": [
+ "### Imports"
+ ] + },
+ { + "cell_type": "code", + "execution_count": 1, + "id": "90979a41", + "metadata": {}, + "outputs": [], + "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import pytensor\n",
+ "import pytensor.tensor as pt\n",
+ "import matplotlib.pyplot as plt\n",
+ "from pytensor.compile.builders import OpFromGraph\n",
+ "from time import perf_counter\n",
+ "from collections import defaultdict\n",
+ "import pymc_extras as pmx\n",
+ "from pymc_extras.statespace import structural as sts\n",
+ "from pytensor.graph.basic import explicit_graph_inputs"
+ ] + },
+ { + "cell_type": "markdown", + "id": "a0d008fc", + "metadata": {}, + "source": [
+ "### Generate a random dataset"
+ ] + },
+ { + "cell_type": "code", + "execution_count": 2, + "id": "fdb156d6", + "metadata": {}, + "outputs": [], + "source": [
+ "mod = (\n",
+ "    sts.LevelTrendComponent(order=2, innovations_order=[0, 1], name='level') +\n",
+ "    sts.AutoregressiveComponent(order=1, name='ar') +\n",
+ "    sts.MeasurementError(name='obs_error')\n",
+ ").build(verbose=False)\n",
+ "\n",
+ "param_values = {\n",
+ "    'initial_level': np.array([10, 0.1]),\n",
+ "    'sigma_level': np.array([1e-2]),\n",
+ "    'params_ar': np.array([0.95]),\n",
+ "    'sigma_ar': np.array(1e-2),\n",
+ "    'sigma_obs_error': np.array(1e-2),\n",
+ "}\n",
+ "\n",
+ "data_fn = pmx.statespace.compile_statespace(mod, steps=100)\n",
+ "hidden_state_data, obs_data = data_fn(**param_values)\n",
+ "\n",
+ "matrices = mod._unpack_statespace_with_placeholders()\n",
+ "\n",
+ "matrix_fn = pytensor.function(list(explicit_graph_inputs(matrices)),\n",
+ "                              matrices)\n",
+ "a0, P0, c, d, T, Z, R, H, Q = matrix_fn(**param_values, initial_state_cov=np.eye(mod.k_states))"
+ ] + },
+ { + "cell_type": "markdown", + "id": "51b7e885", + "metadata": {}, + "source": [
+ "### Symbolic variables"
+ ] + },
+ { + "cell_type": "code", + "execution_count": 3, + "id": "3661408d", + "metadata": {}, + "outputs": [], + "source": [
+ "# Symbolic parameters\n",
+ "A_sym = pt.matrix(\"A\")    # (n, n)\n",
+ "H_sym = pt.matrix(\"H\")    # (p, p)\n",
+ "Q_sym = pt.matrix(\"Q\")    # (n, n)\n",
+ "R_sym = pt.matrix(\"R\")    # (n, n)\n",
+ "T_sym = pt.matrix(\"T\")    # (n, n)\n",
+ "Z_sym = pt.matrix(\"Z\")    # (p, n)\n",
+ "y_sym = pt.matrix(\"y\")    # (T, p) : observations\n",
+ "\n",
+ "a0_sym = pt.vector(\"a0\")  # (n,)\n",
+ "P0_sym = pt.matrix(\"P0\")  # (n, n)\n",
+ "\n",
+ "data_sym = pt.matrix('data_sym')  # [T, obs_dim]"
+ ] + },
+ { + "cell_type": "markdown", + "id": "19e6a32d", + "metadata": {}, + "source": [
+ "## Kalman filter with the standard autodiff gradient"
+ ] + },
+ { + "cell_type": "markdown", + "id": "4fb4cef1", + "metadata": {}, + "source": [
+ "### The Loss\n",
+ "\n",
+ "The negative log-likelihood (NLL) loss is given in the paper as:\n",
+ "\n",
+ "$$\n",
+ "L_{NLL} = \\sum_n \\left( l_{n|n} + l_{n|n-1} \\right)\n",
+ "$$\n",
+ "\n",
+ "where:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&l_{n|n} = 0 \\\\\n",
+ "&l_{n|n-1} = \\log \\det F_n + v_n^T F_n^{-1} v_n\n",
+ "\\end{align}\n",
+ "$$"
+ ] + },
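+ { + "cell_type": "markdown", + "id": "8f1d2a3b", + "metadata": {}, + "source": [
+ "To make the formula concrete, the next cell is a minimal NumPy sketch of a single $l_{n|n-1}$ term on small, arbitrary example matrices (illustrative values only, not taken from the model above)."
+ ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "8f1d2a3c", + "metadata": {}, + "outputs": [], + "source": [
+ "# Minimal NumPy sketch of one per-step NLL contribution; all values are arbitrary examples.\n",
+ "Z_ex = np.array([[1.0, 0.0]])        # observation map (1 observed series, 2 states)\n",
+ "H_ex = np.array([[0.1]])             # observation noise covariance\n",
+ "P_ex = np.eye(2)                     # predicted state covariance P_n|n-1\n",
+ "a_ex = np.array([0.5, 0.0])          # predicted state mean a_n|n-1\n",
+ "y_ex = np.array([0.7])               # observation y_n\n",
+ "\n",
+ "v_ex = y_ex - Z_ex @ a_ex            # innovation v_n\n",
+ "F_ex = Z_ex @ P_ex @ Z_ex.T + H_ex   # innovation covariance F_n\n",
+ "l_ex = np.log(np.linalg.det(F_ex)) + v_ex @ np.linalg.solve(F_ex, v_ex)\n",
+ "print(l_ex)"
+ ] + },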
+ { + "cell_type": "code", + "execution_count": 4, + "id": "35351096", + "metadata": {}, + "outputs": [], + "source": [
+ "def predict(a, P, T, Q):\n",
+ "    a_hat = T @ a            # x_n|n-1\n",
+ "    P_hat = T @ P @ T.T + Q  # P_n|n-1\n",
+ "    return a_hat, P_hat\n",
+ "\n",
+ "def update(y, a, P, Z, H):\n",
+ "    v = y - Z.dot(a)         # z_n\n",
+ "    PZT = P.dot(Z.T)\n",
+ "\n",
+ "    F = Z.dot(PZT) + H       # S_n\n",
+ "    F_inv = pt.linalg.inv(F) # S_n^(-1)\n",
+ "    K = PZT.dot(F_inv)       # K_n\n",
+ "\n",
+ "    I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n",
+ "    a_filtered = a + K.dot(v)  # x_n|n\n",
+ "    P_filtered = I_KZ @ P      # P_n|n\n",
+ "\n",
+ "    inner_term = v.T @ F_inv @ v\n",
+ "    _, F_logdet = pt.linalg.slogdet(F)       # log det S_n\n",
+ "    ll = (F_logdet + inner_term).ravel()[0]  # per-step loss term\n",
+ "\n",
+ "    return [a_filtered, P_filtered, Z.dot(a), F, ll]\n",
+ "\n",
+ "def kalman_step(y, a, P, T, Z, H, Q):\n",
+ "    a_filtered, P_filtered, obs_mu, obs_cov, ll = update(y=y, a=a, P=P, Z=Z, H=H)\n",
+ "    a_hat, P_hat = predict(a=a_filtered, P=P_filtered, T=T, Q=Q)\n",
+ "    return [a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll]\n",
+ "\n",
+ "\n",
+ "# Only a_hat and P_hat recur between steps; a0/P0 act as the initial predicted state.\n",
+ "outputs_info = [None, a0_sym, None, None, P0_sym, None, None]\n",
+ "\n",
+ "results_seq, updates = pytensor.scan(\n",
+ "    kalman_step,\n",
+ "    sequences=[data_sym],\n",
+ "    outputs_info=outputs_info,\n",
+ "    non_sequences=[T_sym, Z_sym, H_sym, Q_sym],\n",
+ "    strict=False,\n",
+ ")\n",
+ "\n",
+ "# --- Loss ---\n",
+ "a_upd_seq, a_pred_seq, y_hat_seq, P_upd_seq, P_pred_seq, obs_cov, ll_seq = results_seq\n",
+ "loss = pt.sum(ll_seq)"
+ ] + },
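+ { + "cell_type": "markdown", + "id": "8f1d2a3d", + "metadata": {}, + "source": [
+ "As a quick sanity check of the filter above (a sketch, not part of the benchmarks), the next cell compiles the loss and the one-step-ahead predictions and evaluates them on the simulated data. Here `R @ Q @ R.T` maps the structural innovation covariance into the full state covariance, mirroring the comparison cells further down."
+ ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "8f1d2a3e", + "metadata": {}, + "outputs": [], + "source": [
+ "# Evaluate the scan-based loss and the one-step-ahead predictions on the simulated data.\n",
+ "check_fn = pytensor.function(\n",
+ "    inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n",
+ "    outputs=[loss, y_hat_seq],\n",
+ ")\n",
+ "loss_val, y_hat_val = check_fn(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T)\n",
+ "print('NLL loss on the simulated data:', loss_val)\n",
+ "\n",
+ "plt.plot(obs_data, label='observed')\n",
+ "plt.plot(y_hat_val.squeeze(), label='one-step-ahead prediction')\n",
+ "plt.legend();"
+ ] + },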
+ { + "cell_type": "markdown", + "id": "ece2f47e", + "metadata": {}, + "source": [
+ "## Custom gradient"
+ ] + },
+ { + "cell_type": "markdown", + "id": "5dc91ae7", + "metadata": {}, + "source": [
+ "### Gradient with respect to **$a_{n-1|n-1}$**\n",
+ "\n",
+ "From the article, we have:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{da_{n-1|n-1}} = T_n^T \\frac{dL}{da_{n|n-1}} \n",
+ "+ \\frac{dl_{n-1|n-1}}{da_{n-1|n-1}} \\quad &\\text{(equation 22)} \\\\\n",
+ "&\\frac{dl_{n|n}}{da_{n|n}} = 0 \\quad &\\text{(equation 28)}\n",
+ "\\end{align}\n",
+ "$$\n",
+ "\n",
+ "Given these two equations, we now have:\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{da_{n-1|n-1}} = T_n^T \\frac{dL}{da_{n|n-1}} \n",
+ "\\end{align}\n",
+ "$$\n",
+ "\n",
+ "### Gradient with respect to **$P_{n-1|n-1}$**\n",
+ "\n",
+ "From the article, we have:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{dP_{n-1|n-1}} = T_n^T \\frac{dL}{dP_{n|n-1}} T_n\n",
+ "+ \\frac{dl_{n-1|n-1}}{dP_{n-1|n-1}} \\quad &\\text{(equation 23)} \\\\\n",
+ "&\\frac{dl_{n|n}}{dP_{n|n}} = 0 \\quad &\\text{(equation 28)}\n",
+ "\\end{align}\n",
+ "$$\n",
+ "\n",
+ "Given these two equations, we now have:\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{dP_{n-1|n-1}} = T_n^T \\frac{dL}{dP_{n|n-1}} T_n\n",
+ "\\end{align}\n",
+ "$$\n"
+ ] + },
+ { + "cell_type": "markdown", + "id": "22a3560b", + "metadata": {}, + "source": [
+ "### Gradient with respect to **$a_{n|n-1}$**\n",
+ "\n",
+ "From the article, we have:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{da_{n|n-1}} = (I - K_n Z_n)^T \\frac{dL}{da_{n|n}} + \\frac{dl_{n|n-1}}{da_{n|n-1}} \\quad &\\text{(equation 20)} \\\\\n",
+ "&\\frac{dl_{n|n-1}}{da_{n|n-1}} = -2 Z_n^{T}F_n^{-1} v_n \\quad &\\text{(equation 30)} \\\\\n",
+ "&\\frac{dL}{da_{n|n}} = T_n^T \\frac{dL}{da_{n+1|n}} \\quad &\\text{(the } a_{n-1|n-1} \\text{ result above, shifted one step)}\n",
+ "\\end{align}\n",
+ "$$\n",
+ "\n",
+ "Given these equations, we now have:\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{da_{n|n-1}} = (I - K_n Z_n)^T T_n^T \\frac{dL}{da_{n+1|n}} - 2 Z_n^{T}F_n^{-1} v_n\n",
+ "\\end{align}\n",
+ "$$\n"
+ ] + },
+ { + "cell_type": "code", + "execution_count": 5, + "id": "ee21ef4e", + "metadata": {}, + "outputs": [], + "source": [
+ "def grad_a_hat(inp, out, out_grad):\n",
+ "    y, a, P, T, Z, H, Q = inp\n",
+ "    a_hat_grad, _, _ = out_grad\n",
+ "\n",
+ "    v = y - Z.dot(a)\n",
+ "\n",
+ "    PZT = P.dot(Z.T)\n",
+ "    F = Z.dot(PZT) + H\n",
+ "    F_inv = pt.linalg.inv(F)\n",
+ "\n",
+ "    K = PZT.dot(F_inv)\n",
+ "    I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n",
+ "\n",
+ "    grad_a_pred = I_KZ.T @ T.T @ a_hat_grad - 2 * Z.T @ F_inv @ v\n",
+ "\n",
+ "    return grad_a_pred"
+ ] + },
+ { + "cell_type": "markdown", + "id": "293d8d65", + "metadata": {}, + "source": [
+ "### Gradient with respect to **$P_{n|n-1}$**\n",
+ "\n",
+ "From the article, we have:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{dP_{n|n-1}} = (I - K_n Z_n)^T [\n",
+ " \\frac{dL}{dP_{n|n}}\n",
+ " + \\frac{1}{2} \\frac{dL}{da_{n|n}} v_n^T H_n^{-1} Z_n\n",
+ " + \\frac{1}{2} Z_n^T H_n^{-1} v_n (\\frac{dL}{da_{n|n}})^T\n",
+ " ](I - K_n Z_n) \n",
+ " + \\frac{dl_{n|n-1}}{dP_{n|n-1}} \\quad &\\text{(equation 21)} \\\\\n",
+ "&\\frac{dl_{n|n-1}}{dP_{n|n-1}} = Z_n^T F_n^{-1} Z_n - Z_n^T F_n^{-1} v_n v_n^T F_n^{-1} Z_n \\quad &\\text{(equation 29)} \\\\\n",
+ "&\\frac{dL}{da_{n|n}} = T_n^T \\frac{dL}{da_{n+1|n}} \\quad &\\text{(the } a_{n-1|n-1} \\text{ result, shifted one step)} \\\\\n",
+ "&\\frac{dL}{dP_{n|n}} = T_n^T \\frac{dL}{dP_{n+1|n}} T_n \\quad &\\text{(the } P_{n-1|n-1} \\text{ result, shifted one step)}\n",
+ "\\end{align}\n",
+ "$$\n",
+ "\n",
+ "Given these equations, we now have:\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{dP_{n|n-1}} = (I - K_n Z_n)^T [\n",
+ " T_n^T \\frac{dL}{dP_{n+1|n}} T_n\n",
+ " + \\frac{1}{2} T_n^T \\frac{dL}{da_{n+1|n}} v_n^T H_n^{-1} Z_n\n",
+ " + \\frac{1}{2} Z_n^T H_n^{-1} v_n (T_n^T \\frac{dL}{da_{n+1|n}})^T\n",
+ " ](I - K_n Z_n) \n",
+ " + Z_n^T F_n^{-1} Z_n \n",
+ " - Z_n^T F_n^{-1} v_n v_n^T F_n^{-1} Z_n\n",
+ "\\end{align}\n",
+ "$$"
+ ] + },
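+ { + "cell_type": "markdown", + "id": "8f1d2a3f", + "metadata": {}, + "source": [
+ "Before implementing the recursion, the next cell is a small NumPy finite-difference check (on random example matrices, added here as a sanity check) of the two per-step derivatives quoted above: equation (30), $\\frac{dl_{n|n-1}}{da_{n|n-1}} = -2 Z_n^T F_n^{-1} v_n$, and equation (29), $\\frac{dl_{n|n-1}}{dP_{n|n-1}} = Z_n^T F_n^{-1} Z_n - Z_n^T F_n^{-1} v_n v_n^T F_n^{-1} Z_n$."
+ ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "8f1d2a40", + "metadata": {}, + "outputs": [], + "source": [
+ "# Finite-difference check of equations (29) and (30) for l(a, P) = logdet F + v.T F^-1 v,\n",
+ "# with v = y - Z a and F = Z P Z.T + H. All matrices below are random examples.\n",
+ "rng = np.random.default_rng(0)\n",
+ "n_s, n_o = 3, 2\n",
+ "Z_c = rng.normal(size=(n_o, n_s))\n",
+ "H_c = 0.5 * np.eye(n_o)\n",
+ "M_c = rng.normal(size=(n_s, n_s))\n",
+ "P_c = M_c @ M_c.T + np.eye(n_s)   # symmetric positive definite covariance\n",
+ "a_c = rng.normal(size=n_s)\n",
+ "y_c = rng.normal(size=n_o)\n",
+ "\n",
+ "def l_of(a, P):\n",
+ "    v = y_c - Z_c @ a\n",
+ "    F = Z_c @ P @ Z_c.T + H_c\n",
+ "    return np.log(np.linalg.det(F)) + v @ np.linalg.solve(F, v)\n",
+ "\n",
+ "F_c = Z_c @ P_c @ Z_c.T + H_c\n",
+ "Fi = np.linalg.inv(F_c)\n",
+ "v_c = y_c - Z_c @ a_c\n",
+ "grad_a_eq30 = -2 * Z_c.T @ Fi @ v_c\n",
+ "grad_P_eq29 = Z_c.T @ Fi @ Z_c - Z_c.T @ Fi @ np.outer(v_c, v_c) @ Fi @ Z_c\n",
+ "\n",
+ "eps = 1e-6\n",
+ "num_a = np.zeros(n_s)\n",
+ "num_P = np.zeros((n_s, n_s))\n",
+ "for i in range(n_s):\n",
+ "    da = np.zeros(n_s); da[i] = eps\n",
+ "    num_a[i] = (l_of(a_c + da, P_c) - l_of(a_c - da, P_c)) / (2 * eps)\n",
+ "    for j in range(n_s):\n",
+ "        dP = np.zeros((n_s, n_s)); dP[i, j] = eps\n",
+ "        num_P[i, j] = (l_of(a_c, P_c + dP) - l_of(a_c, P_c - dP)) / (2 * eps)\n",
+ "\n",
+ "print(np.allclose(grad_a_eq30, num_a, atol=1e-5), np.allclose(grad_P_eq29, num_P, atol=1e-5))"
+ ] + },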
+ { + "cell_type": "code", + "execution_count": 6, + "id": "8c89b018", + "metadata": {}, + "outputs": [], + "source": [
+ "def grad_P_hat(inp, out, out_grad):\n",
+ "    y, a, P, T, Z, H, Q = inp\n",
+ "    a_hat_grad, P_hat_grad, ll_grad = out_grad\n",
+ "\n",
+ "    v = y - Z.dot(a)\n",
+ "    v = v.dimshuffle(0, 'x')\n",
+ "    a_hat_grad = a_hat_grad.dimshuffle(0, 'x')\n",
+ "\n",
+ "    P_filtered_grad = T.T @ P_hat_grad @ T\n",
+ "    a_filtered_grad = T.T @ a_hat_grad\n",
+ "\n",
+ "    PZT = P.dot(Z.T)\n",
+ "    F = Z.dot(PZT) + H\n",
+ "\n",
+ "    H_inv = pt.linalg.inv(H)\n",
+ "    F_inv = pt.linalg.inv(F)\n",
+ "\n",
+ "    K = PZT.dot(F_inv)\n",
+ "    I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n",
+ "\n",
+ "    grad_P_hat = I_KZ.T @ ( P_filtered_grad + 0.5 * a_filtered_grad @ v.T @ H_inv @ Z + 0.5 * Z.T @ H_inv @ v @ a_filtered_grad.T ) @ I_KZ + Z.T @ F_inv @ Z - Z.T @ F_inv @ v @ v.T @ F_inv @ Z\n",
+ "\n",
+ "    return grad_P_hat"
+ ] + },
+ { + "cell_type": "markdown", + "id": "f0f2dce4", + "metadata": {}, + "source": [
+ "### Gradient with respect to **y**\n",
+ "\n",
+ "From the article, we have:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{dy_n} = K_n^T\\frac{dL}{da_{n|n}} + \\frac{dl_{n|n-1}}{dy_n} \\quad &\\text{(equation 24)} \\\\\n",
+ "&\\frac{dl_{n|n-1}}{dy_n} = 2F_n^{-1}v_n \\quad &\\text{(equation 31)} \\\\\n",
+ "&\\frac{dL}{da_{n|n}} = T_n^T \\frac{dL}{da_{n+1|n}} \\quad &\\text{(the } a_{n-1|n-1} \\text{ result, shifted one step)} \\\\\n",
+ "\\end{align}\n",
+ "$$\n",
+ "\n",
+ "Given these equations, we now have:\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{dy_n} = K_n^TT_n^T\\frac{dL}{da_{n+1|n}} + 2F_n^{-1}v_n\n",
+ "\\end{align}\n",
+ "$$\n"
+ ] + },
+ { + "cell_type": "code", + "execution_count": 7, + "id": "bba53a26", + "metadata": {}, + "outputs": [], + "source": [
+ "def grad_y(inp, out, out_grad):\n",
+ "    y, a, P, T, Z, H, Q = inp\n",
+ "    a_hat_grad, P_h_grad, y_grad = out_grad\n",
+ "\n",
+ "    y_hat = Z.dot(a)\n",
+ "    v = y - y_hat\n",
+ "\n",
+ "    PZT = P.dot(Z.T)\n",
+ "    F = Z.dot(PZT) + H\n",
+ "    F_inv = pt.linalg.inv(F)\n",
+ "\n",
+ "    K = PZT.dot(F_inv)\n",
+ "\n",
+ "    return K.T @ T.T @ a_hat_grad + 2 * F_inv @ v"
+ ] + },
+ { + "cell_type": "markdown", + "id": "d6b48789", + "metadata": {}, + "source": [
+ "### Gradient with respect to Q\n",
+ "\n",
+ "From the article, we have:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "\\frac{dL}{dQ_n} = \\frac{dL}{dP_{n|n-1}} & \\quad \\text{(equation 25)}\n",
+ "\\end{align}\n",
+ "$$"
+ ] + },
+ { + "cell_type": "code", + "execution_count": 8, + "id": "c17949b7", + "metadata": {}, + "outputs": [], + "source": [
+ "def grad_Q(inp, out, out_grad):\n",
+ "    _, P_h_grad, _ = out_grad\n",
+ "    return P_h_grad"
+ ] + },
+ { + "cell_type": "markdown", + "id": "f0bc0287", + "metadata": {}, + "source": [
+ "### Gradient with respect to **H**\n",
+ "\n",
+ "From the article, we have:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{dH_n} = K_n^T\\frac{dL}{dP_{n|n}}K_n \n",
+ "- \\frac{1}{2} K_n^T \\frac{dL}{da_{n|n}} v_n^T F_n^{-1}\n",
+ "- \\frac{1}{2} F_n^{-1} v_n (\\frac{dL}{da_{n|n}})^T K_n\n",
+ "+ \\frac{dl_{n|n-1}}{dH_n} \n",
+ "\\quad &\\text{(equation 26)} \\\\\n",
+ "&\\frac{dl_{n|n-1}}{dH_n} = F_n^{-1} - F_n^{-1} v_n v_n^T F_n^{-1} \n",
+ "\\quad &\\text{(equation 31)} \\\\\n",
+ "&\\frac{dL}{da_{n|n}} = T_n^T \\frac{dL}{da_{n+1|n}} \\quad &\\text{(the } a_{n-1|n-1} \\text{ result, shifted one step)} \\\\\n",
+ "&\\frac{dL}{dP_{n|n}} = T_n^T \\frac{dL}{dP_{n+1|n}} T_n \\quad &\\text{(the } P_{n-1|n-1} \\text{ result, shifted one step)}\n",
+ "\\end{align}\n",
+ "$$\n",
+ "\n",
+ "Given these equations, we now have:\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "&\\frac{dL}{dH_n} = K_n^T T_n^T \\frac{dL}{dP_{n+1|n}} T_n K_n \n",
+ "- \\frac{1}{2} K_n^T T_n^T \\frac{dL}{da_{n+1|n}} v_n^T F_n^{-1}\n",
+ "- \\frac{1}{2} F_n^{-1} v_n (T_n^T \\frac{dL}{da_{n+1|n}})^T K_n\n",
+ "+ F_n^{-1} - F_n^{-1} v_n v_n^T F_n^{-1}\n",
+ "\\end{align}\n",
+ "$$\n"
+ ] + },
+ { + "cell_type": "code", + "execution_count": 9, + "id": "84cb6867", + "metadata": {}, + "outputs": [], + "source": [
+ "def grad_H(inp, out, out_grad):\n",
+ "    y, a, P, T, Z, H, Q = inp\n",
+ "    a_hat_grad, P_h_grad, y_grad = out_grad\n",
+ "\n",
+ "    y_hat = Z.dot(a)\n",
+ "    v = y - y_hat\n",
+ "\n",
+ "    PZT = P.dot(Z.T)\n",
+ "    F = Z.dot(PZT) + H\n",
+ "    F_inv = pt.linalg.inv(F)\n",
+ "\n",
+ "    K = PZT.dot(F_inv)\n",
+ "\n",
+ "    v = v.dimshuffle(0, 'x')\n",
+ "    a_hat_grad = a_hat_grad.dimshuffle(0, 'x')\n",
+ "\n",
+ "    a_filtered_grad = T.T @ a_hat_grad\n",
+ "    P_filtered_grad = T.T @ P_h_grad @ T\n",
+ "\n",
+ "    return K.T @ P_filtered_grad @ K - 0.5 * K.T @ a_filtered_grad @ v.T @ F_inv - 0.5 * F_inv @ v @ a_filtered_grad.T @ K + F_inv - F_inv @ v @ v.T @ F_inv"
+ ] + },
+ { + "cell_type": "markdown", + "id": "4fa2ffc0", + "metadata": {}, + "source": [
+ "### Gradient with respect to **T**\n",
+ "\n",
+ "This gradient is not given in the article. Here are the steps leading to the expression used below (a numerical check of the matrix identity follows this cell):\n",
+ "\n",
+ "1 - Only $x_{n|n-1}$ and $P_{n|n-1}$ depend on $T_n$. Hence:\n",
+ "$$\n",
+ "\\frac{\\partial L}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} \\frac{\\partial x_{n|n-1}}{\\partial T} + \\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T}\n",
+ "$$\n",
+ "2 - Using equations (11) and (12) of the article in (1), we directly get:\n",
+ "$$\n",
+ "\\frac{\\partial L}{\\partial x_{n|n-1}} \\frac{\\partial x_{n|n-1}}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} x_{n-1|n-1}^T\n",
+ "$$\n",
+ "3 - Recognizing the quadratic form in equation (2) and using equation (11), we get:\n",
+ "$$\n",
+ "\\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T^T} = P_{n-1|n-1}T_n^T \\frac{\\partial L}{\\partial P_{n|n-1}}^T + P_{n-1|n-1}^T T_n^T \\frac{\\partial L}{\\partial P_{n|n-1}}\n",
+ "$$\n",
+ "4 - Transposing to get the gradient with respect to $T$:\n",
+ "$$\n",
+ "\\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T} = \\frac{\\partial L}{\\partial P_{n|n-1}} T_n P_{n-1|n-1}^T +\\frac{\\partial L}{\\partial P_{n|n-1}}^T T_n P_{n-1|n-1}\n",
+ "$$\n",
+ "5 - Finally, we have:\n",
+ "$$\n",
+ "\\frac{\\partial L}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} x_{n-1|n-1}^T + \\frac{\\partial L}{\\partial P_{n|n-1}} T_n P_{n-1|n-1}^T +\\frac{\\partial L}{\\partial P_{n|n-1}}^T T_n P_{n-1|n-1}\n",
+ "$$"
+ ] + },
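+ { + "cell_type": "markdown", + "id": "8f1d2a41", + "metadata": {}, + "source": [
+ "The step 3-5 manipulation rests on the matrix identity $\\frac{\\partial}{\\partial T} \\operatorname{tr}(G^T T A T^T) = G T A^T + G^T T A$, with $G = \\frac{\\partial L}{\\partial P_{n|n-1}}$ and $A = P_{n-1|n-1}$ playing the roles above. The next cell is a small finite-difference check of that identity on random matrices, added here because the derivation is not in the article."
+ ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "8f1d2a42", + "metadata": {}, + "outputs": [], + "source": [
+ "# Finite-difference check of d tr(G.T @ T @ A @ T.T) / dT = G @ T @ A.T + G.T @ T @ A.\n",
+ "rng = np.random.default_rng(1)\n",
+ "G_c = rng.normal(size=(3, 3))\n",
+ "A_c = rng.normal(size=(3, 3))\n",
+ "T_c = rng.normal(size=(3, 3))\n",
+ "\n",
+ "def f_of(T_mat):\n",
+ "    return np.trace(G_c.T @ T_mat @ A_c @ T_mat.T)\n",
+ "\n",
+ "analytic = G_c @ T_c @ A_c.T + G_c.T @ T_c @ A_c\n",
+ "numeric = np.zeros_like(T_c)\n",
+ "eps = 1e-6\n",
+ "for i in range(3):\n",
+ "    for j in range(3):\n",
+ "        dT = np.zeros((3, 3)); dT[i, j] = eps\n",
+ "        numeric[i, j] = (f_of(T_c + dT) - f_of(T_c - dT)) / (2 * eps)\n",
+ "\n",
+ "print(np.allclose(analytic, numeric, atol=1e-5))"
+ ] + },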
+ { + "cell_type": "code", + "execution_count": 10, + "id": "9a560ed9", + "metadata": {}, + "outputs": [], + "source": [
+ "def grad_T(inp, out, out_grad):\n",
+ "    y, a, P, T, Z, H, Q = inp\n",
+ "    a_hat_grad, P_h_grad, y_grad = out_grad\n",
+ "\n",
+ "    y_hat = Z.dot(a)\n",
+ "    v = y - y_hat\n",
+ "\n",
+ "    PZT = P.dot(Z.T)\n",
+ "    F = Z.dot(PZT) + H\n",
+ "    F_inv = pt.linalg.inv(F)\n",
+ "\n",
+ "    K = PZT.dot(F_inv)\n",
+ "    I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n",
+ "\n",
+ "    v = v.dimshuffle(0, 'x')\n",
+ "    a = a.dimshuffle(0, 'x')\n",
+ "    a_hat_grad = a_hat_grad.dimshuffle(0, 'x')\n",
+ "\n",
+ "    # Recompute the filtered state and covariance of the current step;\n",
+ "    # T enters the loss through predict(a_filtered, P_filtered).\n",
+ "    a_filtered = a + K.dot(v)\n",
+ "    P_filtered = I_KZ @ P\n",
+ "\n",
+ "    return a_hat_grad @ a_filtered.T + P_h_grad @ T @ P_filtered.T + P_h_grad.T @ T @ P_filtered"
+ ] + },
+ { + "cell_type": "markdown", + "id": "bd458dee", + "metadata": {}, + "source": [
+ "### Total gradient"
+ ] + },
+ { + "cell_type": "code", + "execution_count": 11, + "id": "afb362e5", + "metadata": {}, + "outputs": [], + "source": [ + "def custom_grad(inp, out, out_grad):\n", + " y, a, P, T, Z, H, Q = inp\n", + " a_filtered, P_filtered, y_hat = out\n", + " a_hat_grad, P_hat_grad, y_grad = out_grad\n", + "\n", + " PZT = P.dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + "\n", + " y_hat = Z.dot(a)\n", + " v = y - y_hat\n", + "\n", + " H_inv = pt.linalg.inv(H)\n", + " F_inv = pt.linalg.inv(F)\n", + "\n", + " K = PZT.dot(F_inv)\n", + " I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n", + " \n", + " grad_a_pred = I_KZ.T @ T.T @ a_hat_grad - 2 * Z.T @ F_inv @ v\n", + " grad_y = K.T @ T.T @ a_hat_grad + 2 * F_inv @ v\n", + "\n", + "\n", + " a_hat_grad = a_hat_grad.dimshuffle(0, 'x')\n", + " v = v.dimshuffle(0, 'x')\n", + " \n", + " P_filtered_grad = T.T @ P_hat_grad @ T\n", + " a_filtered_grad = T.T @ a_hat_grad \n", + "\n", + " grad_P_hat = I_KZ.T @ ( P_filtered_grad + 0.5 * a_filtered_grad @ v.T @ H_inv @ Z + 0.5 * Z.T @ H_inv @ v @ a_filtered_grad.T ) @ I_KZ + Z.T @ F_inv @ Z - Z.T @ F_inv @ v @ v.T @ F_inv @ Z\n", + " grad_Z = None\n", + " grad_T = None\n", + " grad_Q = P_hat_grad\n", + " grad_H = K.T @ P_filtered_grad @ K - 0.5 * K.T @ a_filtered_grad @ v.T @ F_inv - 0.5 * F_inv @ 
v @ a_filtered_grad.T @ K + F_inv - F_inv @ v @ v.T @ F_inv\n", + "\n", + " return [grad_P_hat,\n", + " grad_a_pred,\n", + " grad_y,\n", + " grad_Z,\n", + " grad_T,\n", + " grad_Q,\n", + " grad_H]\n" + ] + }, + { + "cell_type": "markdown", + "id": "607753a1", + "metadata": {}, + "source": [ + "## Custom Kalman Filter" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7cead2c1", + "metadata": {}, + "outputs": [], + "source": [ + "y_sym = pt.vector(\"y\")\n", + "\n", + "kalman_step_op = OpFromGraph(\n", + " inputs=[y_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=kalman_step(y_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym),\n", + " lop_overrides=[grad_y, grad_a_hat, grad_P_hat, grad_T, None, grad_H, grad_Q],\n", + " inline=True\n", + ")\n", + "\n", + "outputs_info = [None, a0_sym, None, None, P0_sym, None, None]\n", + "\n", + "results_op, updates = pytensor.scan(\n", + " kalman_step_op,\n", + " sequences=[data_sym],\n", + " outputs_info=outputs_info,\n", + " non_sequences=[T_sym, Z_sym, H_sym, Q_sym],\n", + " strict=False,\n", + ")\n", + "# --- Loss ---\n", + "a_upd_op, a_pred_op, y_hat_op, P_upd_op, P_pred_op, obs_cov, ll_op = results_op\n", + "loss_op = pt.sum(ll_op)" + ] + }, + { + "cell_type": "markdown", + "id": "3f79c5c6", + "metadata": {}, + "source": [ + "## Handmade Numpy Backpropagation " + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "b6eb5d48", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_grad_a0(observations, a0, P0, a_pred_seq, P_pred_seq, Z, H, T):\n", + " # Constant\n", + " SHAPE_a0 = a0.shape[0]\n", + " NB_obs = len(observations)\n", + "\n", + " # Initialisation for the backprop\n", + " PZT = P_pred_seq[-2].dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + " F_inv = np.linalg.solve(F, np.eye(F.shape[0]))\n", + " \n", + " grad = [0 for _ in range(NB_obs)]\n", + " grad[-1] = - 2 * Z.T @ F_inv @ (observations[-1] - Z @ a_pred_seq[-2])\n", + "\n", + " # Backprop\n", + " for i in range(3, NB_obs+1):\n", + "\n", + " PZT = P_pred_seq[-i].dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + " F_inv = np.linalg.solve(F, np.eye(F.shape[0]))\n", + "\n", + " K = PZT.dot(F_inv)\n", + " I_KZ = np.eye(SHAPE_a0) - K.dot(Z)\n", + "\n", + " grad[1-i] = I_KZ.T @ T.T @ grad[2-i] - (2 * Z.T @ F_inv @ (observations[1-i] - Z @ a_pred_seq[-i])).T \n", + "\n", + " # Last iter with a0/P0\n", + " PZT = P0.dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + " F_inv = np.linalg.solve(F, np.eye(F.shape[0]))\n", + "\n", + " K = PZT.dot(F_inv)\n", + " I_KZ = np.eye(SHAPE_a0) - K.dot(Z)\n", + "\n", + " grad[0] = I_KZ.T @ T.T @ grad[1] - (2 * Z.T @ F_inv @ (observations[0] - Z @ a0)).T\n", + "\n", + " return grad" + ] + }, + { + "cell_type": "markdown", + "id": "f0575c2c", + "metadata": {}, + "source": [ + "## Speed observation" + ] + }, + { + "cell_type": "markdown", + "id": "c99fddf9", + "metadata": {}, + "source": [ + "### Benchmark for pytensor computed gradients" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "908946b0", + "metadata": {}, + "outputs": [], + "source": [ + "def benchmark_kalman_gradients(loss, obs_data, a0, P0, T, Z, R, H, Q):\n", + " results = defaultdict(dict)\n", + " exec_time = 0\n", + "\n", + " grad_list = pt.grad(loss, [a0_sym])\n", + " f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list,\n", + " )\n", + "\n", + " for _ in range(20):\n", + " \n", + " # --- exécution ---\n", + " t0 = perf_counter()\n", + " _ = f_grad(\n", + " obs_data[:, 
np.newaxis],\n", + " a0,\n", + " P0,\n", + " T,\n", + " Z,\n", + " H,\n", + " R @ Q @ R.T,\n", + " )\n", + " t1 = perf_counter()\n", + " exec_time += (t1 - t0)/20\n", + " \n", + " \n", + " results[\"exec_time\"] = exec_time\n", + "\n", + " return results" + ] + }, + { + "cell_type": "markdown", + "id": "f03a7555", + "metadata": {}, + "source": [ + "### Benchmark for numpy computed gradient" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a85fe92e", + "metadata": {}, + "outputs": [], + "source": [ + "def benchmark_kalman_gradients_np(a_pred_seq, P_pred_seq, obs_data, a0, P0, T, Z, R, H, Q):\n", + " results = defaultdict(dict)\n", + " forward_pass = 0\n", + " backprop = 0\n", + " kalman_fn = pytensor.function(inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=(a_pred_seq, P_pred_seq))\n", + "\n", + " for _ in range(20):\n", + "\n", + " # --- forward pass ---\n", + " t0 = perf_counter()\n", + " a_pred, P_pred = kalman_fn(obs_data[:, np.newaxis],\n", + " a0,\n", + " P0,\n", + " T,\n", + " Z,\n", + " H,\n", + " R@Q@R.T,)\n", + " t1 = perf_counter()\n", + " forward_pass += (t1 - t0)/20\n", + " \n", + "\n", + " # --- Backprop ---\n", + " t0 = perf_counter()\n", + " _ = compute_grad_a0(\n", + " obs_data,\n", + " a0,\n", + " P0,\n", + " a_pred,\n", + " P_pred,\n", + " Z,\n", + " H,\n", + " T,)\n", + " t1 = perf_counter()\n", + " backprop += (t1 - t0)/20\n", + "\n", + " results[\"Forward pass\"] = forward_pass \n", + " results[\"Backprop\"] = backprop\n", + "\n", + " return results" + ] + }, + { + "cell_type": "markdown", + "id": "b413f411", + "metadata": {}, + "source": [ + "### Comparison" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "27a60fb3", + "metadata": {}, + "outputs": [], + "source": [ + "results = benchmark_kalman_gradients(loss, obs_data, a0, P0, T, Z, R, H, Q)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "a413c8e9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {'exec_time': 0.017576184973586352})\n" + ] + } + ], + "source": [ + "print(results)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d35b98d6", + "metadata": {}, + "outputs": [], + "source": [ + "results_op = benchmark_kalman_gradients(loss_op, obs_data, a0, P0, T, Z, R, H, Q)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "539c18c2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {'exec_time': 0.021262520016171044})\n" + ] + } + ], + "source": [ + "print(results_op)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "1e633e75", + "metadata": {}, + "outputs": [], + "source": [ + "results_np = benchmark_kalman_gradients_np(a_pred_seq, P_pred_seq, obs_data, a0, P0, T, Z, R, H, Q)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "7118dfec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {'Forward pass': 0.002400995010975749, 'Backprop': 0.0018996200058609247})\n" + ] + } + ], + "source": [ + "print(results_np)" + ] + }, + { + "cell_type": "markdown", + "id": "d77cf70b", + "metadata": {}, + "source": [ + "## Error observation" + ] + }, + { + "cell_type": "markdown", + "id": "90fabd6f", + "metadata": {}, + "source": [ + "### Comparing the gradient with respect to a0" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "fbae0189", + 
"metadata": {}, + "outputs": [], + "source": [ + "# First the classic way with autodiff\n", + "\n", + "grad_list = pt.grad(loss, [a0_sym])\n", + "f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list,\n", + ")\n", + "\n", + "grad_a0 = f_grad(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T)\n", + "\n", + "# Now using our OpFromGraph custom gradient\n", + "\n", + "grad_list_op = pt.grad(loss_op, [a0_sym])\n", + "f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list_op,\n", + ")\n", + "\n", + "grad_a0_op = f_grad(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T)\n", + "\n", + "# And here using our handmaid numpy backprop\n", + "\n", + "kalman_fn = pytensor.function(inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=(a_pred_seq, P_pred_seq))\n", + "a_pred, P_pred = kalman_fn(obs_data[:, np.newaxis], a0, P0, T, Z, H, R@Q@R.T)\n", + "\n", + "grad_a0_np = compute_grad_a0(obs_data, a0, P0, a_pred, P_pred, Z, H, T)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "c3a114b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Comparison between classic a0 gradient and our custom OpFromGraph : True\n", + "Comparison between classic a0 gradient and our handmade NumPy backprop : True\n" + ] + } + ], + "source": [ + "print(\"Comparison between classic a0 gradient and our custom OpFromGraph :\", np.allclose(grad_a0, grad_a0_op))\n", + "print(\"Comparison between classic a0 gradient and our handmade NumPy backprop :\", np.allclose(grad_a0, grad_a0_np))" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "867d5e2f", + "metadata": {}, + "outputs": [], + "source": [ + "# First the classic way with autodiff\n", + "\n", + "grad_list = pt.grad(loss, [data_sym, a0_sym, P0_sym, T_sym, H_sym, Q_sym])\n", + "f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list,\n", + ")\n", + "\n", + "grad_a0 = f_grad(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T)\n", + "\n", + "# Now using our OpFromGraph custom gradient\n", + "\n", + "grad_list_op = pt.grad(loss_op, [data_sym, a0_sym, P0_sym, T_sym, H_sym, Q_sym])\n", + "f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list_op,\n", + ")\n", + "\n", + "grad_a0_op = f_grad(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "25f0a57b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Comparison between classic y gradient and our custom OpFromGraph : True\n", + "Comparison between classic a0 gradient and our custom OpFromGraph : True\n", + "Comparison between classic P0 gradient and our custom OpFromGraph : True\n", + "Comparison between classic T gradient and our custom OpFromGraph : True\n", + "Comparison between classic H gradient and our custom OpFromGraph : True\n", + "Comparison between classic Q gradient and our custom OpFromGraph : True\n" + ] + } + ], + "source": [ + "print(\"Comparison between classic y gradient and our custom OpFromGraph :\", np.allclose(grad_a0[0], grad_a0_op[0]))\n", + "print(\"Comparison between classic a0 gradient and our custom OpFromGraph :\", np.allclose(grad_a0[1], grad_a0_op[1]))\n", + "print(\"Comparison 
between classic P0 gradient and our custom OpFromGraph :\", np.allclose((grad_a0[2] + grad_a0[2].T)/2, grad_a0_op[2]))\n", + "print(\"Comparison between classic T gradient and our custom OpFromGraph :\", np.allclose(grad_a0[3], grad_a0_op[3]))\n", + "print(\"Comparison between classic H gradient and our custom OpFromGraph :\", np.allclose(grad_a0[4], grad_a0_op[4]))\n", + "print(\"Comparison between classic Q gradient and our custom OpFromGraph :\", np.allclose((grad_a0[5] + grad_a0[5].T)/2, grad_a0_op[5]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "CausalPy", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}