diff --git a/doc/_thumbnails/autodiff/vector_jacobian_product.png b/doc/_thumbnails/autodiff/vector_jacobian_product.png new file mode 100644 index 0000000000..8b530c77cf Binary files /dev/null and b/doc/_thumbnails/autodiff/vector_jacobian_product.png differ diff --git a/doc/gallery/autodiff/vector_jacobian_product.ipynb b/doc/gallery/autodiff/vector_jacobian_product.ipynb new file mode 100644 index 0000000000..2ca6d2c7f2 --- /dev/null +++ b/doc/gallery/autodiff/vector_jacobian_product.ipynb @@ -0,0 +1,1071 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4dd1aa8a-f595-42c1-ad75-716f497c726b", + "metadata": {}, + "source": [ + "# Vector Jacobian Product" + ] + }, + { + "cell_type": "markdown", + "id": "ff301ea1be9ef0a6", + "metadata": {}, + "source": [ + "At the core of autodiff is the vector-Jacobian product (VJP), or in PyTensor's archaic terminology, the L-Operator (because the vector is on the left). The Jacobian is the matrix (or tensor) of all first-order partial derivatives. Let us completely ignore what the vector means, and think how do we go about computing the product of a general vector with the Jacobian matrix?" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "35d314e1728148d5", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:20:19.996660Z", + "start_time": "2025-07-06T20:20:19.594782Z" + } + }, + "outputs": [], + "source": [ + "import pytensor.tensor as pt\n", + "from pytensor.gradient import Lop, jacobian as jacobian_raw\n", + "from pytensor.graph import rewrite_graph\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e805b85ceb6e8667", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:20:20.067703Z", + "start_time": "2025-07-06T20:20:20.065308Z" + } + }, + "outputs": [], + "source": [ + "def jacobian(*args, vectorize=True, **kwargs):\n", + " return jacobian_raw(*args, vectorize=vectorize, **kwargs)\n", + " \n", + "\n", + "def simplify_print(graph, **print_options):\n", + " rewrite_graph(graph, include=(\"fast_run\",), exclude=(\"inplace\", \"BlasOpt\")).dprint(**print_options)" + ] + }, + { + "cell_type": "markdown", + "id": "2d3afc0b692bd81", + "metadata": {}, + "source": [ + "## Elemtwise operations" + ] + }, + { + "cell_type": "markdown", + "id": "3b08aefde765313d", + "metadata": {}, + "source": [ + "The naive way is to create the full Jacobian matrix and then right-multiply it by the vector. Let's look at a concrete example for the Elemtwise operation log(x)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "153da4ec5f8b5a71", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:20:20.128996Z", + "start_time": "2025-07-06T20:20:20.124242Z" + } + }, + "outputs": [], + "source": [ + "x = pt.vector(\"x\")\n", + "log_x = pt.log(x)" + ] + }, + { + "cell_type": "markdown", + "id": "91012e92f0eafee9", + "metadata": {}, + "source": [ + "If x has length 3, the Jacobian of y with respect to x is a 3x3 matrix, since there are 3 outputs and 3 inputs.\n", + "\n", + "Each entry contains the partial derivative of a one of the outputs (rows) with respect to one of the inputs (columns).\n", + "\n", + "$$\n", + "J = \\begin{pmatrix}\n", + "\\frac{\\partial y_1}{\\partial x_1} & \\frac{\\partial y_1}{\\partial x_2} & \\frac{\\partial y_1}{\\partial x_3} \\\\\n", + "\\frac{\\partial y_2}{\\partial x_1} & \\frac{\\partial y_2}{\\partial x_2} & \\frac{\\partial y_2}{\\partial x_3} \\\\\n", + "\\frac{\\partial y_3}{\\partial x_1} & \\frac{\\partial y_3}{\\partial x_2} & \\frac{\\partial y_3}{\\partial x_3}\n", + "\\end{pmatrix}\n", + "$$\n", + "\n", + "For the elementwise operation log(x), the Jacobian is a diagonal matrix, as each input affects only the corresponding output. For the log operation the partial derivatives are given by $\\frac{1}{x_i}$, so the Jacobian is:\n", + "\n", + "$$\n", + "J = \\begin{pmatrix}\n", + "\\frac{1}{x_1} & 0 & 0 \\\\\n", + "0 & \\frac{1}{x_2} & 0 \\\\\n", + "0 & 0 & \\frac{1}{x_3}\n", + "\\end{pmatrix}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7c359f386b4cb918", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:20:20.195333Z", + "start_time": "2025-07-06T20:20:20.172499Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True_div [id A]\n", + " ├─ Eye{dtype='float64'} [id B]\n", + " │ ├─ Shape_i{0} [id C]\n", + " │ │ └─ x [id D]\n", + " │ ├─ Shape_i{0} [id C]\n", + " │ │ └─ ···\n", + " │ └─ 0 [id E]\n", + " └─ ExpandDims{axis=0} [id F]\n", + " └─ x [id D]\n" + ] + } + ], + "source": [ + "J_log = jacobian(log_x, x)\n", + "simplify_print(J_log)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a3441a71c6ddb3ad", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:20:20.517814Z", + "start_time": "2025-07-06T20:20:20.237690Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1. , 0. , 0. ],\n", + " [0. , 0.5 , 0. ],\n", + " [0. , 0. , 0.33333333]])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "J_log.eval({\"x\": [1.0, 2.0, 3.0]})" + ] + }, + { + "cell_type": "markdown", + "id": "f57772bc5c85c82c", + "metadata": {}, + "source": [ + "To get the vector-Jacobian product, we will left-multiply the Jacobian by a vector v. In this case, it simplifies to an elementwise division of the vector v by the input vector x:\n", + "\n", + "$$\n", + "v^T \\cdot J = \\begin{pmatrix}\n", + "v_1 \\\\ v_2 \\\\ v_3\n", + "\\end{pmatrix}^T \\cdot \\begin{pmatrix}\n", + "\\frac{1}{x_1} & 0 & 0 \\\\\n", + "0 & \\frac{1}{x_2} & 0 \\\\\n", + "0 & 0 & \\frac{1}{x_3}\n", + "\\end{pmatrix} = \\begin{pmatrix}\n", + "\\frac{v_1}{x_1} \\\\ \\frac{v_2}{x_2} \\\\ \\frac{v_3}{x_3}\n", + "\\end{pmatrix}\n", + "$$\n", + "\n", + "It is unnecessary to compute the whole Jacobian matrix, and then perform a vector-matrix multiplication. The `Lop` returns the smart computations directly:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "776b81a3c3039098", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:20:20.968461Z", + "start_time": "2025-07-06T20:20:20.953510Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True_div [id A]\n", + " ├─ v [id B]\n", + " └─ x [id C]\n" + ] + } + ], + "source": [ + "v = pt.vector(\"v\")\n", + "vjp_log = Lop(log_x, wrt=x, eval_points=v)\n", + "simplify_print(vjp_log)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b42553ba90bb69e9", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:20:21.211452Z", + "start_time": "2025-07-06T20:20:21.043061Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([4. , 2.5, 2. ])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vjp_log.eval({\"x\": [1.0, 2.0, 3.0], \"v\": [4.0, 5.0, 6.0]})" + ] + }, + { + "cell_type": "markdown", + "id": "cf099c75e01f75f8", + "metadata": {}, + "source": [ + "## Cumsum operation" + ] + }, + { + "cell_type": "markdown", + "id": "a662a8e6b30e079c", + "metadata": {}, + "source": [ + "A pattern that will become obvious in this notebook is that we can often exploit some property of the Jacobian matrix (and that we want to multiply it by a vector) to compute the VJP cheaply. Let's take a look at the cumulative sum operation." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f1c068fed4abb1a7", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:20:21.240090Z", + "start_time": "2025-07-06T20:20:21.227054Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1., 3., 6.])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = pt.vector(\"x\")\n", + "cumsum_x = pt.cumsum(x)\n", + "cumsum_x.eval({\"x\": [1.0, 2.0, 3.0]})" + ] + }, + { + "cell_type": "markdown", + "id": "a9b41bc35d7c1716", + "metadata": {}, + "source": [ + "The jacobian of the cumulative sum operation is a lower triangular matrix of ones, since the first input affects all outputs additively, the second input affects all outputs but the first, and so on, until the last input which only affects the last output. If x has length 3:\n", + "\n", + "$$\n", + "J = \\begin{pmatrix}\n", + "1 & 0 & 0 \\\\\n", + "1 & 1 & 0 \\\\\n", + "1 & 1 & 1 \\\\\n", + "\\end{pmatrix}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "87a87483-e92a-42aa-b42f-509739063888", + "metadata": {}, + "source": [ + "PyTensor autodiff builds this jacobian in a funny way. Starting from a diagonal matrix, it flips the columns, performs a cumsum across the them and then flips them back. A more direct way would do cumsum along the row of the diagonal matrix, but since a flip is just a view (no copy needed) it doesn't actually cost us much." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1a5e9c70f2681436", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:20:21.366061Z", + "start_time": "2025-07-06T20:20:21.290442Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Subtensor{:, ::step} [id A]\n", + " ├─ CumOp{1, add} [id B]\n", + " │ └─ Subtensor{:, ::step} [id C]\n", + " │ ├─ Eye{dtype='float64'} [id D]\n", + " │ │ ├─ Shape_i{0} [id E]\n", + " │ │ │ └─ x [id F]\n", + " │ │ ├─ Shape_i{0} [id E]\n", + " │ │ │ └─ ···\n", + " │ │ └─ 0 [id G]\n", + " │ └─ -1 [id H]\n", + " └─ -1 [id H]\n" + ] + } + ], + "source": [ + "J_cumsum = jacobian(cumsum_x, x)\n", + "simplify_print(J_cumsum)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2c932d61-40af-4c65-a45b-d397d5b0e3a0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 1]])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "J_cumsum.eval({\"x\": [1.0, 2.0, 3.0]}).astype(int)" + ] + }, + { + "cell_type": "markdown", + "id": "789fe3a70f42f6c5", + "metadata": {}, + "source": [ + "The left-multiplication of the Jacobian by a vector v has a special structure as well. Let's write it down:\n", + "\n", + "$$\n", + "v^T \\cdot J = \\begin{pmatrix}\n", + "v_1 \\\\ v_2 \\\\ v_3\n", + "\\end{pmatrix}^T \\cdot \\begin{pmatrix}\n", + "1 & 0 & 0 \\\\\n", + "1 & 1 & 0 \\\\\n", + "1 & 1 & 1 \\\\\n", + "\\end{pmatrix} = \\begin{pmatrix}\n", + "v_1 + v2 + v 3 \\\\ v_2 + v_3 \\\\ v_3\n", + "\\end{pmatrix}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "530901210f4afebe", + "metadata": {}, + "source": [ + "The final result is a cumulative sum of the vector v, but in reverse order." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "e4611d5f792ecd5e", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-06T20:24:45.832808Z", + "start_time": "2025-07-06T20:24:45.809493Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Subtensor{::step} [id A]\n", + " ├─ CumOp{None, add} [id B]\n", + " │ └─ Subtensor{::step} [id C]\n", + " │ ├─ v [id D]\n", + " │ └─ -1 [id E]\n", + " └─ -1 [id E]\n" + ] + } + ], + "source": [ + "v = pt.vector(\"v\")\n", + "vjp_cumsum = Lop(cumsum_x, x, v)\n", + "simplify_print(vjp_cumsum)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e7695008-1622-45cb-ae36-5b9187d2abcf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([3., 2., 1.])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vjp_cumsum.eval({\"x\": [1.0, 2.0, 3.0], \"v\": [1, 1, 1]})" + ] + }, + { + "cell_type": "markdown", + "id": "d5a03eb5545625b5", + "metadata": {}, + "source": [ + "## Convolution operation" + ] + }, + { + "cell_type": "markdown", + "id": "7b27b04c-1abe-4f78-ad4a-35ff35614093", + "metadata": {}, + "source": [ + "Next, we shall look at an operation with two inputs - the discrete convolution." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "43f42d4e-46cf-4a19-97c3-5c9a6888d824", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0., 1., 1., -2.])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = pt.vector(\"x\")\n", + "y = pt.vector(\"y\", shape=(2,))\n", + "convolution_xy = pt.signal.convolve1d(x, y, mode=\"full\")\n", + "convolution_xy.eval({\"x\": [0, 1, 2], \"y\": [1, -1]})" + ] + }, + { + "cell_type": "markdown", + "id": "26224b6d-0164-4421-989d-b277e1ec38cc", + "metadata": {}, + "source": [ + "If you're not familiar with convolution, we get those four numbers by padding `x` with zeros and then performing an inner product with the flipped `y`, one pair of values at a time" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "56075ceb-dc1d-4e17-8e95-e43dbc28783a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0, 1, 1, -2])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_padded = np.array([0, 0, 1, 2, 0])\n", + "res = np.array([\n", + " x_padded[0:2] @ [-1, 1],\n", + " x_padded[1:3] @ [-1, 1],\n", + " x_padded[2:4] @ [-1, 1],\n", + " x_padded[3:5] @ [-1, 1],\n", + "])\n", + "res" + ] + }, + { + "cell_type": "markdown", + "id": "bdb3a78a-ca4e-4adc-98f6-65663ffd0ec6", + "metadata": {}, + "source": [ + "Let's focus on the Jacobian wrt to y, as that's smaller. If you look at the expression above you'll see that it implies the following jacobian:\n", + "\n", + "$$\n", + "J = \\begin{pmatrix}\n", + "x_1 & 0 \\\\\n", + "x_2 & x_1 \\\\\n", + "x_3 & x_2 \\\\\n", + "0 & x_3 \\\\\n", + "\\end{pmatrix}\n", + "$$\n", + "\n", + "The constant zeros come from the padding. Curious how PyTensor builds this sort of jacobian?" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "aee5b3ed-1533-479a-abba-6375435c5268", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Blockwise{Convolve1d, (n),(k),()->(o)} [id A]\n", + " ├─ Eye{dtype='float64'} [id B]\n", + " │ ├─ Add [id C]\n", + " │ │ ├─ 1 [id D]\n", + " │ │ └─ Shape_i{0} [id E]\n", + " │ │ └─ x [id F]\n", + " │ ├─ Add [id C]\n", + " │ │ └─ ···\n", + " │ └─ 0 [id G]\n", + " ├─ ExpandDims{axis=0} [id H]\n", + " │ └─ Subtensor{::step} [id I]\n", + " │ ├─ x [id F]\n", + " │ └─ -1 [id J]\n", + " └─ [False] [id K]\n" + ] + } + ], + "source": [ + "J_convolution = jacobian(convolution_xy, y)\n", + "simplify_print(J_convolution)" + ] + }, + { + "cell_type": "markdown", + "id": "046b91b7-47cb-492e-9a7b-5c92de0cab93", + "metadata": {}, + "source": [ + "It performs a batched \"valid\" convolution between eye(4) and the flipped x vector. In a valid convolution, there is no padding, and we only multiply the sub-sequences that match in length." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1ab05d9a-33fd-46a9-abfd-4de25cd6e520", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0., 0.],\n", + " [1., 0.],\n", + " [2., 1.],\n", + " [0., 2.]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "J_convolution.eval({\"x\": [0, 1, 2]})" + ] + }, + { + "cell_type": "markdown", + "id": "05577d83-091a-48fb-a383-854f69c4cab5", + "metadata": {}, + "source": [ + "Following the theme, is there any special structure in this Jacobian that can be exploited to compute VJP efficiently?" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "89caf646-0213-4aa6-9aca-b7cb0230c8bc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Convolve1d [id A]\n", + " ├─ v [id B]\n", + " ├─ Subtensor{::step} [id C]\n", + " │ ├─ x [id D]\n", + " │ └─ -1 [id E]\n", + " └─ ScalarFromTensor [id F]\n", + " └─ False [id G]\n" + ] + } + ], + "source": [ + "v = pt.vector(\"v\", shape=(4,))\n", + "vjp_convolution = Lop(convolution_xy, y, v)\n", + "simplify_print(vjp_convolution)" + ] + }, + { + "cell_type": "markdown", + "id": "6b369467-9ceb-4f4f-86f7-486801ff3275", + "metadata": {}, + "source": [ + "It's just the \"valid\" convolution between v and x flipped. Our Jacobian has a [toeplitz structure](https://en.wikipedia.org/wiki/Toeplitz_matrix), and the dot product between such a matrix and a vector is equivalent to a discrete convolution!" + ] + }, + { + "cell_type": "markdown", + "id": "4d568e5b-310d-4a4e-a89e-198a337a379b", + "metadata": {}, + "source": [ + "$$\n", + "v^T \\cdot J = \\begin{pmatrix}\n", + "v_1 \\\\ v_2 \\\\ v_3 \\\\ v4\n", + "\\end{pmatrix}^T \\cdot \\begin{pmatrix}\n", + "x_1 & 0 \\\\\n", + "x_2 & x_1 \\\\\n", + "x_3 & x_2 \\\\\n", + "0 & x_3 \\\\\n", + "\\end{pmatrix}\n", + "= v \\ast x_{[::-1]}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "b4811469-cb72-4303-98a9-cc75140b323d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 8., 11.])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vjp_convolution.eval({\"v\": [1, 2, 3, 4], \"x\": [0, 1, 2]})" + ] + }, + { + "cell_type": "markdown", + "id": "d5623e88-9210-4d2c-bcbd-6d37b991a01d", + "metadata": {}, + "source": [ + "## Transpose operation\n", + "\n", + "For a final example let's look at matrix tranposition. This is a simple operation, but is no longer a vector function." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "cd3d779b-6bc3-4462-9904-82f6177ad839", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(None, 2)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "A = pt.matrix(\"A\", shape=(2, None))\n", + "transpose_A = A.T\n", + "transpose_A.type.shape" + ] + }, + { + "cell_type": "markdown", + "id": "60a08160-10b8-494d-8031-621a8e7d0521", + "metadata": {}, + "source": [ + "To be able to think about the Jacobian (and then the VJP) we need to look at this operation in terms of raveled input and outputs." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "ab4fc0f0-882e-46ff-8129-08731e26a025", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0., 3., 1., 4., 2., 5.])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transpose_A.ravel().eval({\"A\": np.arange(6).reshape(2, 3)})" + ] + }, + { + "cell_type": "markdown", + "id": "e3043a1b-179e-49c1-9e3e-587740b8ac7d", + "metadata": {}, + "source": [ + "The Jacobian is then a (6 x 6) permutation matrix like this:\n", + "\n", + "$$\n", + "J = \\begin{pmatrix}\n", + "1 & 0 & 0 & 0 & 0 & 0\\\\\n", + "0 & 0 & 0 & 1 & 0 & 0\\\\\n", + "0 & 1 & 0 & 0 & 0 & 0\\\\\n", + "0 & 0 & 0 & 0 & 1 & 0\\\\\n", + "0 & 0 & 1 & 0 & 0 & 0\\\\\n", + "0 & 0 & 0 & 0 & 0 & 1\\\\\n", + "\\end{pmatrix}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "4765b01a-8078-48ac-ad22-1ce1dabf9ea7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 1., 0., 0.],\n", + " [0., 1., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 1., 0.],\n", + " [0., 0., 1., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 1.]])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "J_transpose = jacobian(transpose_A.ravel(), A).reshape((6, 6))\n", + "J_transpose.eval({\"A\": np.zeros((2, 3))})" + ] + }, + { + "cell_type": "markdown", + "id": "db3f0b31-9514-4724-ac77-755bbb0c2be7", + "metadata": {}, + "source": [ + "PyTensor builds this Jacobian with two reshapes and a tranpose of an `eye`." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "c65f6740-f809-4fc0-831f-61bd2e09f510", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reshape{2} [id A]\n", + " ├─ Transpose{axes=[0, 2, 1]} [id B]\n", + " │ └─ Reshape{3} [id C]\n", + " │ ├─ Eye{dtype='float64'} [id D]\n", + " │ │ ├─ Mul [id E]\n", + " │ │ │ ├─ 2 [id F]\n", + " │ │ │ └─ Shape_i{1} [id G]\n", + " │ │ │ └─ A [id H]\n", + " │ │ ├─ Mul [id E]\n", + " │ │ │ └─ ···\n", + " │ │ └─ 0 [id I]\n", + " │ └─ MakeVector{dtype='int64'} [id J]\n", + " │ ├─ Mul [id E]\n", + " │ │ └─ ···\n", + " │ ├─ Shape_i{1} [id G]\n", + " │ │ └─ ···\n", + " │ └─ 2 [id F]\n", + " └─ [6 6] [id K]\n" + ] + } + ], + "source": [ + "simplify_print(J_transpose)" + ] + }, + { + "cell_type": "markdown", + "id": "0e39d268-12c3-463f-9ec7-e8b48bbd95d4", + "metadata": {}, + "source": [ + "To recreate the outcome of `Lop`, we ravel the `V` matrix, multiply it with the Jacobian defined on the raveled function, and reshape the result to the original input shape." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "2d09cb7b-60f1-4d22-adbe-fb88d487ba86", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0., 2., 4.],\n", + " [1., 3., 5.]])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V = pt.matrix(\"V\", shape=(3, 2))\n", + "naive_vjp_transpose = (V.ravel() @ J_transpose).reshape(V.shape[::-1])\n", + "\n", + "vjp_eval_dict = {\"V\": np.arange(6).reshape((3, 2)), \"A\": np.zeros((2, 3))}\n", + "naive_vjp_transpose.eval(vjp_eval_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "9164ca89-a49c-4768-a5c4-f16e16ffc0b4", + "metadata": {}, + "source": [ + "Because J is a permutation matrix, the multiplication with it simply rearranges the entries of `V`. \n", + "\n", + "What's more, after the reshape, we end up with a simple transposition of the original `V` matrix!\n", + "\n", + "Unsurprisingly, `Lop` takes the direct shortcut:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "e29d3283-77f6-4539-8d41-851848267cd6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transpose{axes=[1, 0]} [id A]\n", + " └─ V [id B]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Lop(transpose_A, A, V).dprint()" + ] + }, + { + "cell_type": "markdown", + "id": "6de26bd3-4e43-444a-abf9-eee682d201b0", + "metadata": {}, + "source": [ + "## VJP and auto-diff" + ] + }, + { + "cell_type": "markdown", + "id": "4f92c6a4-27d2-4854-8b66-038eca62ebb7", + "metadata": {}, + "source": [ + "It is time to reveal the meaning of the mysterious vector (or reshaped tensor) `v`. In the context ouf auto-diff, it is the vector that accumulates the partial derivatives of intermediate computations. If you chain the VJP for each operation in your graph you obtain reverse-mode autodiff. \n", + "\n", + "Let's look at a simple example with the operations we discussed already:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "1a8a50dd-d62b-4229-b885-e96f4e43f76e", + "metadata": {}, + "outputs": [], + "source": [ + "x = pt.vector(\"x\")\n", + "log_x = pt.log(x)\n", + "cumsum_log_x = pt.cumsum(log_x)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "8a32e30e-5885-49e1-8973-e734d66f05b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True_div [id A]\n", + " ├─ Subtensor{::step} [id B]\n", + " │ ├─ CumOp{None, add} [id C]\n", + " │ │ └─ Alloc [id D]\n", + " │ │ ├─ [1.] [id E]\n", + " │ │ └─ Shape_i{0} [id F]\n", + " │ │ └─ x [id G]\n", + " │ └─ -1 [id H]\n", + " └─ x [id G]\n" + ] + } + ], + "source": [ + "grad_out_wrt_x = pt.grad(cumsum_log_x.sum(), x)\n", + "simplify_print(grad_out_wrt_x)" + ] + }, + { + "cell_type": "markdown", + "id": "5c788690-c1b9-4462-b78f-b4c16bfdc0b0", + "metadata": {}, + "source": [ + "You may recognize the gradient components from the examples above. The gradient simplifies to `cumsum(ones_like(x))[::-1] / x`\n", + "\n", + "We can build the same graph manually, by chaining two `Lop` calls and setting the initial `grad_vec` to `ones` with the right shape." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "b7b81d74-fc84-4de9-97df-0cb47b745825", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True_div [id A]\n", + " ├─ Subtensor{::step} [id B]\n", + " │ ├─ CumOp{None, add} [id C]\n", + " │ │ └─ Alloc [id D]\n", + " │ │ ├─ [1.] [id E]\n", + " │ │ └─ Shape_i{0} [id F]\n", + " │ │ └─ x [id G]\n", + " │ └─ -1 [id H]\n", + " └─ x [id G]\n" + ] + } + ], + "source": [ + "grad_vec = pt.ones_like(cumsum_log_x)\n", + "grad_out_wrt_x = Lop(log_x, x, Lop(cumsum_log_x, log_x, grad_vec))\n", + "simplify_print(grad_out_wrt_x)" + ] + }, + { + "cell_type": "markdown", + "id": "a2e9e121-7363-41f9-9c1d-d6792db36bcc", + "metadata": {}, + "source": [ + "Similarly, forward-mode autodiff makes use of the R-Operator (Rop) or Jacobian-vector product (JVP) to accumulate the partial derivations from inputs to outputs." + ] + }, + { + "cell_type": "markdown", + "id": "eff94009-64f2-4c5d-b4d5-5ff36a86b7fa", + "metadata": {}, + "source": [ + "## Conclusion" + ] + }, + { + "cell_type": "markdown", + "id": "123a10f2-ef4d-4cd9-86b7-b2b7975a1c31", + "metadata": {}, + "source": [ + "We hope this sheds some light on how PyTensor (and most auto-diff frameworks) implement vector Jacobian products efficiently, in a way that avoids both having to build the full jacobian and having to multiply with it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0614f4ce-6a30-4fdc-8d74-bea8203ff3c4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index dc92238010..c27602b7b1 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -13,6 +13,7 @@ ) from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.op import Op +from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.link.c.type import EnumList, Generic @@ -360,7 +361,7 @@ def grad(self, inputs, output_gradients): ) def infer_shape(self, fgraph, node, shapes): - if self.axis is None: + if self.axis is None and len(shapes[0]) > 1: return [(prod(shapes[0]),)] # Flatten return shapes @@ -473,6 +474,25 @@ def cumprod(x, axis=None): return CumOp(axis=axis, mode="mul")(x) +@_vectorize_node.register(CumOp) +def vectorize_cum_op(op: CumOp, node: Apply, batch_x): + """Vectorize the CumOp to work on a batch of inputs.""" + [original_x] = node.inputs + batch_ndim = batch_x.ndim - original_x.ndim + axis = op.axis + if axis is None and original_x.ndim == 1: + axis = 0 + elif axis is not None: + axis = normalize_axis_index(op.axis, original_x.ndim) + + if axis is None: + # Ravel all unbatched dimensions and perform CumOp on the last axis + batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x] + return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled) + else: + return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x) + + def diff(x, n=1, axis=-1): """Calculate the `n`-th order discrete difference along the given `axis`. diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 6c881a0312..c9961e0564 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -948,6 +948,9 @@ def perform(self, node, inputs, out_): out[0] = np.asarray(x.__getitem__(cdata)) def infer_shape(self, fgraph, node, shapes): + def _is_constant(const, x): + return isinstance(const, Constant) and const.data.item() == x + xshp = shapes[0] assert len(xshp) == node.inputs[0].ndim outshp = [] @@ -961,10 +964,17 @@ def infer_shape(self, fgraph, node, shapes): # If it is the default (None, None, None) slice, or a variant, # the shape will be xl if ( - (idx.start in [None, 0]) - and (idx.stop in [None, sys.maxsize]) - and (idx.step is None or idx.step == 1) + (idx.start is None or _is_constant(idx.start, 0)) + and (idx.stop is None or _is_constant(idx.stop, sys.maxsize)) + and (idx.step is None or _is_constant(idx.step, 1)) + ): + outshp.append(xl) + elif ( + (idx.start is None) + and (idx.stop is None) + and _is_constant(idx.step, -1) ): + # Reverse slice outshp.append(xl) else: cnf = get_canonical_form_slice(idx, xl)[0] diff --git a/scripts/generate_gallery.py b/scripts/generate_gallery.py index 15e94ca7f4..3409cb1b3d 100644 --- a/scripts/generate_gallery.py +++ b/scripts/generate_gallery.py @@ -57,6 +57,7 @@ folder_title_map = { "introduction": "Introduction", "rewrites": "Graph Rewriting", + "autodiff": "Automatic Differentiation", "scan": "Looping in Pytensor", "optimize": "Optimization in Pytensor", }