This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit d6b7f86

Add complexities and references for NTK implementations. (#907)
* Add complexities and references for NTK implementations.
* Fix result names; rename "outer product" -> "matrix product".
* Fix names
1 parent 137e4e1 commit d6b7f86

File tree

1 file changed: +38, -27 lines changed

notebooks/neural_tangent_kernels.ipynb

Lines changed: 38 additions & 27 deletions
@@ -7,7 +7,7 @@
 "source": [
 "# Neural Tangent Kernels\n",
 "\n",
-"The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents), demonstrates how to easily compute this quantity using functorch."
+"The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), demonstrates how to easily compute this quantity using functorch."
 ]
 },
 {
@@ -79,7 +79,7 @@
 "\n",
 "functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n",
 "\n",
-"We'll use functorch's make_functional to accomplish the first step. If your module has buffers, you'll want to use make_functional_with_buffers instead."
+"We'll use functorch's `make_functional` to accomplish the first step. If your module has buffers, you'll want to use `make_functional_with_buffers` instead."
 ]
 },
 {
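
For readers following along outside the notebook, here is a minimal sketch of this setup step. The model below is a hypothetical stand-in (the notebook's actual architecture is not part of this diff); only `make_functional` and the `fnet_single` wrapper mirror what the text describes:

```python
import torch
import torch.nn as nn
from functorch import make_functional

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hypothetical stand-in model; the notebook's real model is not shown in this diff.
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 10))

    def forward(self, x):
        return self.layers(x)

net = SimpleNet().to(device)

# make_functional splits the module into a stateless function and its parameters.
# If the module had buffers, make_functional_with_buffers would be used instead.
fnet, params = make_functional(net)

# The NTK computation wants a function of (params, single input) -> single output,
# so wrap fnet to add and remove the batch dimension.
def fnet_single(params, x):
    return fnet(params, x.unsqueeze(0)).squeeze(0)

# Hypothetical data batches used in the examples below.
x_train = torch.randn(20, 16, device=device)
x_test = torch.randn(5, 16, device=device)
```
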
@@ -117,13 +117,15 @@
 "id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248",
 "metadata": {},
 "source": [
-"## Compute the NTK: method 1\n",
+"## Compute the NTK: method 1 (Jacobian contraction)\n",
 "\n",
-"We're ready to compute the empirical NTK. The empirical NTK for two data points `x1` and `x2` is defined as an inner product between the Jacobian of the model evaluated at `x1` and the Jacobian of the model evaluated at `x2`:\n",
+"We're ready to compute the empirical NTK. The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product between the Jacobian of the model evaluated at $x_1$ and the Jacobian of the model evaluated at $x_2$:\n",
 "\n",
-"$$J_{net}(x1) \\cdot J_{net}^T(x2)$$\n",
+"$$J_{net}(x_1) J_{net}^T(x_2)$$\n",
 "\n",
-"In the batched case where `x1` is a batch of data points and `x2` is a batch of data points, then we want the inner product between the Jacobians of all combinations of data points from `x1` and `x2`. Here's how to compute the NTK in the batched case:"
+"In the batched case where $x_1$ is a batch of data points and $x_2$ is a batch of data points, we want the matrix product between the Jacobians of all combinations of data points from $x_1$ and $x_2$.\n",
+"\n",
+"The first method consists of doing just that: computing the two Jacobians and contracting them. Here's how to compute the NTK in the batched case:"
 ]
 },
 {
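
The hunks below only show the first lines of this function being renamed. For reference, here is a hedged sketch of how the full Jacobian-contraction computation might look, given the description above (the einsum contraction is an assumption; it is not shown in this diff):

```python
import torch
from functorch import vmap, jacrev

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
    # Compute J(x1): per-sample Jacobians w.r.t. params, flattened over parameter dims.
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]            # each entry: [N, O, P_i]

    # Compute J(x2) the same way.
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2]            # each entry: [M, O, P_i]

    # Contract J(x1) with J(x2)^T for every pair of data points,
    # then sum the contributions from each parameter tensor.
    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2)
                          for j1, j2 in zip(jac1, jac2)])
    return result.sum(0)                           # shape: [N, M, O, O]
```

For example, with batches of size N=20 and M=5 and an output size O=10, the returned tensor would have shape (20, 5, 10, 10).
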
@@ -133,7 +135,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def empirical_ntk(fnet_single, params, x1, x2):\n",
+"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n",
 "    # Compute J(x1)\n",
 "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
 "    jac1 = [j.flatten(2) for j in jac1]\n",
@@ -163,7 +165,7 @@
 }
 ],
 "source": [
-"result = empirical_ntk(fnet_single, params, x_train, x_test)\n",
+"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n",
 "print(result.shape)"
 ]
 },
@@ -182,7 +184,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def empirical_ntk(fnet_single, params, x1, x2, compute='full'):\n",
+"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n",
 "    # Compute J(x1)\n",
 "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
 "    jac1 = [j.flatten(2) for j in jac1]\n",
@@ -222,32 +224,39 @@
 }
 ],
 "source": [
-"result = empirical_ntk(fnet_single, params, x_train, x_test, 'trace')\n",
+"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n",
 "print(result.shape)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa",
+"metadata": {},
+"source": [
+"The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) $+ N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
+]
+},
 {
 "cell_type": "markdown",
 "id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa",
 "metadata": {},
 "source": [
-"## Compute the NTK: method 2\n",
+"## Compute the NTK: method 2 (NTK-vector products)\n",
 "\n",
-"The next method we will discuss is a way to compute the NTK implicitly. This has different tradeoffs compared to the previous one and it is generally more efficient when your model has large parameters; we recommend trying out both methods to see which works better.\n",
+"The next method we will discuss is a way to compute the NTK using NTK-vector products.\n",
 "\n",
-"Here's our definition of NTK:\n",
+"This method reformulates the NTK as a stack of NTK-vector products applied to the columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n",
 "\n",
-"$$J_{net}(x1) \\cdot J_{net}^T(x2)$$\n",
+"$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n",
+"where $e_o\\in \\mathbb{R}^O$ are the column vectors of the identity matrix $I_O$.\n",
 "\n",
-"The implicit computation reformulates the problem by adding an identity matrix and rearranging the matrix-multiplies:\n",
+"- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n",
+"- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n",
+"- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n",
 "\n",
-"$$= J_{net}(x1) \\cdot J_{net}^T(x2) \\cdot I$$\n",
-"$$= (J_{net}(x1) \\cdot (J_{net}^T(x2) \\cdot I))$$\n",
+"This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n",
 "\n",
-"- Let $vjps = (J_{net}^T(x2) \\cdot I)$. We can use a vector-Jacobian product to compute this.\n",
-"- Now, consider $J_{net}(x1) \\cdot vjps$. This is a Jacobian-vector product!\n",
-"\n",
-"This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK. Let's code that up:"
+"Let's code that up:"
 ]
 },
 {
@@ -257,7 +266,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def empirical_ntk_implicit(func, params, x1, x2, compute='full'):\n",
+"def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n",
 "    def get_ntk(x1, x2):\n",
 "        def func_x1(params):\n",
 "            return func(params, x1)\n",
@@ -280,7 +289,7 @@
 "        return vmap(get_ntk_slice)(basis)\n",
 "    \n",
 "    # get_ntk(x1, x2) computes the NTK for a single data point x1, x2\n",
-"    # Since the x1, x2 inputs to empirical_ntk_implicit are batched,\n",
+"    # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n",
 "    # we actually wish to compute the NTK between every pair of data points\n",
 "    # between {x1} and {x2}. That's what the vmaps here do.\n",
 "    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n",
@@ -300,17 +309,19 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"result_implicit = empirical_ntk_implicit(fnet_single, params, x_test, x_train)\n",
-"result_explicit = empirical_ntk(fnet_single, params, x_test, x_train)\n",
-"assert torch.allclose(result_implicit, result_explicit, atol=1e-5)"
+"result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n",
+"result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n",
+"assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)"
 ]
 },
 {
 "cell_type": "markdown",
 "id": "84253466-971d-4475-999c-fe3de6bd25b5",
 "metadata": {},
 "source": [
-"Our code for `empirical_ntk_implicit` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch."
+"Our code for `empirical_ntk_ntk_vps` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n",
+"\n",
+"The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1, Jacobian contraction ($N^2 O$ instead of $N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$ term, where $P$ is the total number of the model's parameters). Therefore, this method is preferable when $O P$ is large relative to $[FP]$, such as for fully-connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
 ]
 },
 ]
 }
 ],
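
Given the complexity discussion in the final cell, here is a hypothetical way to compare the two methods on a concrete model. This is not part of the commit; it assumes the functions and data defined in the notebook are in scope, and GPU timings would additionally need `torch.cuda.synchronize()` inside each lambda:

```python
import timeit

n_repeats = 5

# Time the Jacobian-contraction implementation.
t_contraction = timeit.timeit(
    lambda: empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test),
    number=n_repeats,
)
# Time the NTK-vector-products implementation.
t_ntk_vps = timeit.timeit(
    lambda: empirical_ntk_ntk_vps(fnet_single, params, x_train, x_test),
    number=n_repeats,
)

print(f"Jacobian contraction: {t_contraction / n_repeats:.4f} s per call")
print(f"NTK-vector products:  {t_ntk_vps / n_repeats:.4f} s per call")
```

Which method wins depends on the output size, the parameter count, and the cost of a forward pass, as the complexity notes above describe.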
