
Commit 0661929

Sohl-Dickstein authored and Sam Schoenholz committed
Update docs now that improved standard parameterization note is on arXiv.
PiperOrigin-RevId: 291473673
1 parent 7c50657 commit 0661929

File tree

2 files changed (+14, -7 lines)


README.md

Lines changed: 8 additions & 3 deletions
@@ -245,7 +245,7 @@ import neural_tangents as nt # 64-bit precision enabled
We remark the following differences between our library and the JAX one.

* All `nt.stax` layers are instantiated with a function call, i.e. `nt.stax.Relu()` vs `jax.experimental.stax.Relu`.
- * All layers with trainable parameters use the _NTK parameterization_ by default (see [[10]](#5-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler), Remark 1). However, Dense and Conv layers also support the _standard parameterization_ via a `parameterization` keyword argument. <!-- TODO(jaschasd) add link to note deriving NTK for standard parameterization -->
+ * All layers with trainable parameters use the _NTK parameterization_ by default (see [[10]](#5-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler), Remark 1). However, Dense and Conv layers also support the _standard parameterization_ via a `parameterization` keyword argument (see [[15]](#15-on-the-infinite-width-limit-of-neural-networks-with-a-standard-parameterization)).
* `nt.stax` and `jax.experimental.stax` may have different layers and options available (for example `nt.stax` layers support `CIRCULAR` padding, but only `NHWC` data format).

### Python 2 is not supported
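As a quick illustration of the `parameterization` keyword argument referenced in the hunk above, here is a minimal sketch of requesting the standard parameterization when building an `nt.stax` network. The `serial` combinator, the returned `(init_fn, apply_fn, kernel_fn)` triple, and the `W_std`/`b_std` values are assumptions following the usual `stax` convention, not something fixed by this commit.

```python
# Hedged sketch: passing parameterization='standard' to Dense layers, per the
# README bullet above. serial and the returned triple are assumed to follow the
# usual stax convention; W_std/b_std values are illustrative only.
import neural_tangents as nt

init_fn, apply_fn, kernel_fn = nt.stax.serial(
    nt.stax.Dense(512, W_std=1.5, b_std=0.05, parameterization='standard'),
    nt.stax.Relu(),
    nt.stax.Dense(1, W_std=1.5, b_std=0.05, parameterization='standard'),
)
```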
@@ -358,10 +358,10 @@ a small dataset using a small learning rate.
Neural Tangents has been used in the following papers:

- * [Disentangling Trainability and Generalization in Deep Learning](https://arxiv.org/abs/1912.13053) \
+ * [Disentangling Trainability and Generalization in Deep Learning.](https://arxiv.org/abs/1912.13053) \
Lechao Xiao, Jeffrey Pennington, Samuel S. Schoenholz

- * [Information in Infinite Ensembles of Infinitely-Wide Neural Networks](https://arxiv.org/abs/1911.09189) \
+ * [Information in Infinite Ensembles of Infinitely-Wide Neural Networks.](https://arxiv.org/abs/1911.09189) \
Ravid Shwartz-Ziv, Alexander A. Alemi

* [Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel.](https://arxiv.org/abs/1905.13654) \
@@ -372,6 +372,9 @@ Descent.](https://arxiv.org/abs/1902.06720) \
Jaehoon Lee*, Lechao Xiao*, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha
Sohl-Dickstein, Jeffrey Pennington

+ * [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) \
+ Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee
+
Please let us know if you make use of the code in a publication and we'll add it
to the list!
@@ -423,3 +426,5 @@ If you use the code in a publication, please cite the repo using the .bib,
##### [13] [Mean Field Residual Networks: On the Edge of Chaos.](https://arxiv.org/abs/1712.08969) *NeurIPS 2017.* Greg Yang, Samuel S. Schoenholz

##### [14] [Wide Residual Networks.](https://arxiv.org/abs/1605.07146) *BMVC 2018.* Sergey Zagoruyko, Nikos Komodakis
+
+ ##### [15] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee

neural_tangents/stax.py

Lines changed: 6 additions & 4 deletions
@@ -25,8 +25,9 @@
similarly to `init_fn` and `apply_fn`.

2) In layers with random weights, NTK parameterization is used by default
- (https://arxiv.org/abs/1806.07572, page 3). Standard parameterization can
- be specified for Conv and Dense layers by a keyword argument.
+ (https://arxiv.org/abs/1806.07572, page 3). Standard parameterization
+ (https://arxiv.org/abs/2001.07301) can be specified for Conv and Dense layers
+ by a keyword argument.

3) Some functionality may be missing (e.g. `BatchNorm`), and some may be present
only in our library (e.g. `CIRCULAR` padding, `LayerNorm`, `GlobalAvgPool`,
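To make the two parameterizations named in this docstring concrete, the finite-width layer equations spelled out in the `Dense` docstring below can be written, for input width n, as:

```latex
% NTK parameterization: unit-variance parameters, width scaling kept explicit.
z_i = \frac{W_{\mathrm{std}}}{\sqrt{n}} \sum_{j=1}^{n} W_{ij} x_j + b_{\mathrm{std}}\, b_i,
\qquad W_{ij} \sim \mathcal{N}(0, 1), \quad b_i \sim \mathcal{N}(0, 1).

% Standard parameterization: the same scaling absorbed into the parameter variances.
z_i = \sum_{j=1}^{n} W_{ij} x_j + b_i,
\qquad W_{ij} \sim \mathcal{N}\bigl(0, W_{\mathrm{std}}^{2}/n\bigr), \quad b_i \sim \mathcal{N}\bigl(0, b_{\mathrm{std}}^{2}\bigr).
```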
@@ -951,8 +952,9 @@ def Dense(out_dim,
Under ntk parameterization (https://arxiv.org/abs/1806.07572, page 3),
weights and biases are initialized as W_ij ~ N(0,1), b_i ~ N(0,1), and
the finite width layer equation is z_i = W_std / sqrt([width]) sum_j
- W_ij x_j + b_std b_i Under standard parameterization, weights and biases
- are initialized as W_ij ~ N(0,W_std^2/[width]), b_i ~ N(0,b_std^2), and
+ W_ij x_j + b_std b_i Under standard parameterization
+ (https://arxiv.org/abs/2001.07301), weights and biases are initialized
+ as W_ij ~ N(0,W_std^2/[width]), b_i ~ N(0,b_std^2), and
the finite width layer equation is z_i = \sum_j W_ij x_j + b_i.