|
7 | 7 | "source": [
|
8 | 8 | "# Neural Tangent Kernels\n",
|
9 | 9 | "\n",
|
10 |
| - "The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents), demonstrates how to easily compute this quantity using functorch." |
| 10 | + "The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), demonstrates how to easily compute this quantity using functorch." |
11 | 11 | ]
|
12 | 12 | },
|
13 | 13 | {
|
|
79 | 79 | "\n",
|
80 | 80 | "functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n",
|
81 | 81 | "\n",
|
82 |
| - "We'll use functorch's make_functional to accomplish the first step. If your module has buffers, you'll want to use make_functional_with_buffers instead." |
| 82 | + "We'll use functorch's `make_functional` to accomplish the first step. If your module has buffers, you'll want to use `make_functional_with_buffers` instead." |
83 | 83 | ]
|
84 | 84 | },
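| | + {
| | + "cell_type": "markdown",
| | + "metadata": {},
| | + "source": [
| | + "For concreteness, here is a minimal sketch of how this setup might look. The model `net`, its input shape, and the example batches `x_train` and `x_test` below are placeholders (the tutorial's actual model and data may differ); the important pieces are `make_functional`, which splits the module into a pure function plus its parameters, and `fnet_single`, which takes the parameters and a single unbatched input and returns a single output."
| | + ]
| | + },
| | + {
| | + "cell_type": "code",
| | + "execution_count": null,
| | + "metadata": {},
| | + "outputs": [],
| | + "source": [
| | + "import torch\n",
| | + "import torch.nn as nn\n",
| | + "from functorch import make_functional\n",
| | + "\n",
| | + "# Placeholder model and data: substitute your own architecture and inputs.\n",
| | + "net = nn.Sequential(nn.Linear(5, 32), nn.ReLU(), nn.Linear(32, 3))\n",
| | + "x_train = torch.randn(20, 5)\n",
| | + "x_test = torch.randn(5, 5)\n",
| | + "\n",
| | + "# make_functional splits the module into a pure function and its parameters.\n",
| | + "fnet, params = make_functional(net)\n",
| | + "\n",
| | + "def fnet_single(params, x):\n",
| | + "    # Wrap the functional model so a single unbatched input maps to a single output,\n",
| | + "    # by adding and removing a batch dimension around the call.\n",
| | + "    return fnet(params, x.unsqueeze(0)).squeeze(0)"
| | + ]
| | + },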
|
85 | 85 | {
|
|
117 | 117 | "id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248",
|
118 | 118 | "metadata": {},
|
119 | 119 | "source": [
|
120 |
| - "## Compute the NTK: method 1\n", |
| 120 | + "## Compute the NTK: method 1 (Jacobian contraction)\n", |
121 | 121 | "\n",
|
122 |
| - "We're ready to compute the empirical NTK. The empirical NTK for two data points `x1` and `x2` is defined as an inner product between the Jacobian of the model evaluated at `x1` and the Jacobian of the model evaluated at `x2`:\n", |
| 122 | + "We're ready to compute the empirical NTK. The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product between the Jacobian of the model evaluated at $x_1$ and the Jacobian of the model evaluated at $x_2$:\n", |
123 | 123 | "\n",
|
124 |
| - "$$J_{net}(x1) \\cdot J_{net}^T(x2)$$\n", |
| 124 | + "$$J_{net}(x_1) J_{net}^T(x_2)$$\n", |
125 | 125 | "\n",
|
126 |
| - "In the batched case where `x1` is a batch of data points and `x2` is a batch of data points, then we want the inner product between the Jacobians of all combinations of data points from `x1` and `x2`. Here's how to compute the NTK in the batched case:" |
| 126 | + "In the batched case where $x_1$ is a batch of data points and $x_2$ is a batch of data points, then we want the matrix product between the Jacobians of all combinations of data points from $x_1$ and $x_2$.\n", |
| 127 | + "\n", |
| 128 | + "The first method consists of doing just that - computing the two Jacobians, and contracting them. Here's how to compute the NTK in the batched case:" |
127 | 129 | ]
|
128 | 130 | },
|
129 | 131 | {
|
|
133 | 135 | "metadata": {},
|
134 | 136 | "outputs": [],
|
135 | 137 | "source": [
|
136 |
| - "def empirical_ntk(fnet_single, params, x1, x2):\n", |
| 138 | + "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n", |
137 | 139 | " # Compute J(x1)\n",
|
138 | 140 | " jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
|
139 | 141 | " jac1 = [j.flatten(2) for j in jac1]\n",
|
|
163 | 165 | }
|
164 | 166 | ],
|
165 | 167 | "source": [
|
166 |
| - "result = empirical_ntk(fnet_single, params, x_train, x_test)\n", |
| 168 | + "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n", |
167 | 169 | "print(result.shape)"
|
168 | 170 | ]
|
169 | 171 | },
|
|
182 | 184 | "metadata": {},
|
183 | 185 | "outputs": [],
|
184 | 186 | "source": [
|
185 |
| - "def empirical_ntk(fnet_single, params, x1, x2, compute='full'):\n", |
| 187 | + "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n", |
186 | 188 | " # Compute J(x1)\n",
|
187 | 189 | " jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
|
188 | 190 | " jac1 = [j.flatten(2) for j in jac1]\n",
|
|
222 | 224 | }
|
223 | 225 | ],
|
224 | 226 | "source": [
|
225 |
| - "result = empirical_ntk(fnet_single, params, x_train, x_test, 'trace')\n", |
| 227 | + "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n", |
226 | 228 | "print(result.shape)"
|
227 | 229 | ]
|
228 | 230 | },
|
| 231 | + { |
| 232 | + "cell_type": "markdown", |
| 233 | + "id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa", |
| 234 | + "metadata": {}, |
| 235 | + "source": [ |
| 236 | + "The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) $+ N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
| 237 | + ] |
| 238 | + }, |
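| | + {
| | + "cell_type": "markdown",
| | + "metadata": {},
| | + "source": [
| | + "As a rough, purely illustrative plug-in of hypothetical sizes: with $N = 100$, $O = 10$, $P = 10^6$, and $[FP] \\approx P = 10^6$ operations (roughly the case for a fully-connected model), the Jacobian term is $N O [FP] \\approx 10^9$ operations, while the contraction term is $N^2 O^2 P = 10^{12}$ operations, so the contraction dominates."
| | + ]
| | + },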
229 | 239 | {
|
230 | 240 | "cell_type": "markdown",
|
231 | 241 | "id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa",
|
232 | 242 | "metadata": {},
|
233 | 243 | "source": [
|
234 |
| - "## Compute the NTK: method 2\n", |
| 244 | + "## Compute the NTK: method 2 (NTK-vector products)\n", |
235 | 245 | "\n",
|
236 |
| - "The next method we will discuss is a way to compute the NTK implicitly. This has different tradeoffs compared to the previous one and it is generally more efficient when your model has large parameters; we recommend trying out both methods to see which works better.\n", |
| 246 | + "The next method we will discuss is a way to compute the NTK using NTK-vector products.\n", |
237 | 247 | "\n",
|
238 |
| - "Here's our definition of NTK:\n", |
| 248 | + "This method reformulates the NTK as a stack of NTK-vector products applied to the columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n",
239 | 249 | "\n",
|
240 |
| - "$$J_{net}(x1) \\cdot J_{net}^T(x2)$$\n", |
| 250 | + "$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n", |
| 251 | + "where $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix $I_O$.\n",
241 | 252 | "\n",
|
242 |
| - "The implicit computation reformulates the problem by adding an identity matrix and rearranging the matrix-multiplies:\n", |
| 253 | + "- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n",
| 254 | + "- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n",
| 255 | + "- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n", |
243 | 256 | "\n",
|
244 |
| - "$$= J_{net}(x1) \\cdot J_{net}^T(x2) \\cdot I$$\n", |
245 |
| - "$$= (J_{net}(x1) \\cdot (J_{net}^T(x2) \\cdot I))$$\n", |
| 257 | + "This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n", |
246 | 258 | "\n",
|
247 |
| - "- Let $vjps = (J_{net}^T(x2) \\cdot I)$. We can use a vector-Jacobian product to compute this.\n", |
248 |
| - "- Now, consider $J_{net}(x1) \\cdot vjps$. This is a Jacobian-vector product!\n", |
249 |
| - "\n", |
250 |
| - "This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK. Let's code that up:" |
| 259 | + "Let's code that up:" |
251 | 260 | ]
|
252 | 261 | },
|
253 | 262 | {
|
|
257 | 266 | "metadata": {},
|
258 | 267 | "outputs": [],
|
259 | 268 | "source": [
|
260 |
| - "def empirical_ntk_implicit(func, params, x1, x2, compute='full'):\n", |
| 269 | + "def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n", |
261 | 270 | " def get_ntk(x1, x2):\n",
|
262 | 271 | " def func_x1(params):\n",
|
263 | 272 | " return func(params, x1)\n",
|
|
280 | 289 | " return vmap(get_ntk_slice)(basis)\n",
|
281 | 290 | " \n",
|
282 | 291 | " # get_ntk(x1, x2) computes the NTK for a single data point x1, x2\n",
|
283 |
| - " # Since the x1, x2 inputs to empirical_ntk_implicit are batched,\n", |
| 292 | + " # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n", |
284 | 293 | " # we actually wish to compute the NTK between every pair of data points\n",
|
285 | 294 | " # between {x1} and {x2}. That's what the vmaps here do.\n",
|
286 | 295 | " result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n",
|
|
300 | 309 | "metadata": {},
|
301 | 310 | "outputs": [],
|
302 | 311 | "source": [
|
303 |
| - "result_implicit = empirical_ntk_implicit(fnet_single, params, x_test, x_train)\n", |
304 |
| - "result_explicit = empirical_ntk(fnet_single, params, x_test, x_train)\n", |
305 |
| - "assert torch.allclose(result_implicit, result_explicit, atol=1e-5)" |
| 312 | + "result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n", |
| 313 | + "result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n", |
| 314 | + "assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)" |
306 | 315 | ]
|
307 | 316 | },
|
308 | 317 | {
|
309 | 318 | "cell_type": "markdown",
|
310 | 319 | "id": "84253466-971d-4475-999c-fe3de6bd25b5",
|
311 | 320 | "metadata": {},
|
312 | 321 | "source": [
|
313 |
| - "Our code for `empirical_ntk_implicit` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch." |
| 322 | + "Our code for `empirical_ntk_ntk_vps` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n", |
| 323 | + "\n", |
| 324 | + "The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1, Jacobian contraction ($N^2 O$ instead of $N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$ term, where $P$ is the total number of the model's parameters). Therefore, this method is preferable when $O P$ is large relative to $[FP]$, as is the case for fully-connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details.\n",
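| | + "\n",
| | + "As a rough, purely illustrative comparison using the same hypothetical sizes as above ($N = 100$, $O = 10$, $P = 10^6$, $[FP] \\approx 10^6$): NTK-vector products cost $N^2 O [FP] \\approx 10^{11}$ operations, versus roughly $10^{12}$ for Jacobian contraction, i.e. a speedup of about $O P / [FP] = 10$ on the dominant term."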
314 | 325 | ]
|
315 | 326 | }
|
316 | 327 | ],
|
|