Neural network layers example #581
Open: RobertStanforth wants to merge 5 commits into google-research:main from RobertStanforth:main.
Changes from all 5 commits:
- 9e2266a Corrected off-by-one error in bounds check on applicability of conv f…
- 4e68e2f Merge branch 'google-research:main' into main
- ec11ecf Prototype layers library using type system to manage layers' params a…
- 55be172 Alternative prototype layers library, using existential types to hide…
- 59bb740 Corrections to trainModel in existentially-typed layers example
' # Neural Networks

' ## NN Prelude

' This ReLU implementation is for one neuron.
Use `map relu` to apply it to a 1D table, `map (map relu)` for 2D, etc.

def relu (input : Float) : Float =
  select (input > 0.0) input 0.0
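
' A quick sanity check (illustrative only, not part of the original file):
`relu` zeroes out negative inputs and passes positive ones through unchanged.

:p relu (-2.0)
:p map relu [(-1.0), 0.5, 2.0]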

' A pair of vector spaces is also a vector space.

instance [Add a, Add b] Add (a & b)
  add = \(a, b) (c, d). ((a + c), (b + d))
  sub = \(a, b) (c, d). ((a - c), (b - d))
  zero = (zero, zero)

instance [VSpace a, VSpace b] VSpace (a & b)
  scaleVec = \s (a, b). (scaleVec s a, scaleVec s b)
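
' For instance (illustrative only), pairs of floats now support `+` and
`scaleVec` componentwise:

:p (1.0, 2.0) + (3.0, 4.0)
:p scaleVec 2.0 (1.0, 2.0)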

' Layer type, describing a function with trainable side-parameters.
This may represent a primitive layer (e.g. dense), a composition of layers
(e.g. resnet_block), or an entire network.
`forward` is the function implementing the layer computation.
`init` provides initial values of the parameters, given a random key.

data Layer inp:Type out:Type params:Type =
  AsLayer {forward:(params -> inp -> out) & init:(Key -> params)}

' Convenience functions to extract the `forward` and `init` functions.

def forward (l:Layer i o p) (params:p) (x:i) : o =
  (AsLayer l') = l
  (getAt #forward l') params x

def init (l:Layer i o p) (k:Key) : p =
  (AsLayer l') = l
  (getAt #init l') k

' Adapt a pure function into a (parameterless) `Layer`.
This is unused and for illustration only: we will see below how to apply
pure functions directly with `trace_map`.

def as_layer (f:u->v) : Layer u v Unit =
  AsLayer {
    forward = \() x. f x,
    init = \_. ()
  }
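
' For example (illustrative only), wrapping `relu` yields a layer whose
parameters are just `()`:

relu_layer = as_layer relu
:p forward relu_layer (init relu_layer (newKey 0)) (-2.0)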

' ## Layers

' Dense layer.

def DenseParams (a:Type) (b:Type) : Type =
  ((a=>b=>Float) & (b=>Float))

def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) (DenseParams a b) =
  AsLayer {
    forward = (\(weight, bias) x.
      for j. (bias.j + sum for i. weight.i.j * x.i)),
    init = arb
  }
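
' A small usage sketch (illustrative only): a 2 -> 3 dense layer with
randomly initialised parameters.

dense_layer = dense (Fin 2) (Fin 3)
w_dense = init dense_layer (newKey 0)
:t w_dense
:p forward dense_layer w_dense [1.0, 2.0]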

' CNN layer.

def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
  ((outc=>inc=>Fin kh=>Fin kw=>Float) &
   (outc=>Float))

def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
           (kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float) :
           outc=>(Fin h)=>(Fin w)=>Float =
  for o i j.
    (i', j') = (ordinal i, ordinal j)
    case (i' + kh) <= h && (j' + kw) <= w of
      True ->
        sum for (ki, kj, inp).
          (di, dj) = (fromOrdinal (Fin h) (i' + (ordinal ki)),
                      fromOrdinal (Fin w) (j' + (ordinal kj)))
          x.inp.di.dj * kernel.o.inp.ki.kj
      False -> zero

def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int) :
        Layer (inc=>(Fin h)=>(Fin w)=>Float)
              (outc=>(Fin h)=>(Fin w)=>Float)
              (CNNParams inc outc kw kh) =
  AsLayer {
    forward = (\(weight, bias) x. for o i j. (conv2d x weight).o.i.j + bias.o),
    init = arb
  }
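
' A usage sketch (illustrative only; `cnn4x4` is a hypothetical wrapper whose
signature pins the implicit image size to 4x4, since the size cannot be
inferred from the explicit arguments alone):

def cnn4x4 (u:Unit) : Layer ((Fin 1)=>(Fin 4)=>(Fin 4)=>Float)
                            ((Fin 2)=>(Fin 4)=>(Fin 4)=>Float)
                            (CNNParams (Fin 1) (Fin 2) 3 3) =
  cnn (Fin 1) (Fin 2) 3 3

cnn_layer = cnn4x4 ()
:p forward cnn_layer (init cnn_layer (newKey 0)) (for _ _ _. 1.0)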

' ## Tracing

' A tracer is an object that we pass through a network function. The tracer
is used as if it were an actual input to the network, and invoking constituent
layers results in new tracers denoting intermediate values of the network.
A tracer is implemented as a partial neural network, from the original inputs
to the current intermediate values.
Starting with an identity layer as the input tracer, tracing thus accumulates
an output tracer that implements the full network.

def trace (f : Layer inp inp Unit -> Layer inp out params) : Layer inp out params =
  tracer = AsLayer {
    forward = \() x. x,
    init = \key. ()
  }
  f tracer

' Converts a pure function (e.g. `map relu`) to a callable acting on tracers.

def trace_map (f : inp -> out) (tracer : Layer ph inp ps) : Layer ph out ps =
  AsLayer {
    forward = \w x. f (forward tracer w x),
    init = \key. init tracer key
  }
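
' A minimal sketch (illustrative only) of tracing in action: the traced
function applies a pure op via `trace_map`, so the resulting network is
parameterless.

inc_net = trace \x. trace_map (\v:Float. v + 1.0) x
:p forward inc_net (init inc_net (newKey 0)) 3.0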

' Converts a two-arg pure function (e.g. `add`) to a callable acting on tracers.

def trace_map2 (f : inp0 -> inp1 -> out)
               (tracer0 : Layer ph inp0 ps0) (tracer1 : Layer ph inp1 ps1) :
               Layer ph out (ps0 & ps1) =
  AsLayer {
    forward = \(w0, w1) x. f (forward tracer0 w0 x) (forward tracer1 w1 x),
    init = (\key.
      [k0, k1] = splitKey key
      (init tracer0 k0, init tracer1 k1))
  }

' Adapts an existing `Layer` into a callable that can be invoked within a
larger layer being traced.

def callable (layer : Layer inp out params) (tracer : Layer ph inp ps) :
             Layer ph out (params & ps) =
  AsLayer {
    forward = \w x. forward layer (fst w) $ forward tracer (snd w) x,
    init = (\key.
      [k0, k1] = splitKey key
      (init layer k0, init tracer k1))
  }

' ## Networks

' MLP

mlp = trace \x.
  dense1 = callable $ dense (Fin 2) (Fin 25)
  x = dense1 x
  x = trace_map (map relu) x
  dense3 = callable $ dense _ (Fin 2)
  x = dense3 x
  x

w_mlp = init mlp (newKey 1)
:t w_mlp
:p forward mlp w_mlp (for _. 0.)

' ResNet - incorrect first attempt

' The following does not work correctly because the (last assigned) value of `x`
' is consumed twice, directly or indirectly, by both inputs to `trace_map2`.
' Because it's a tracer, this leads to two independent copies of `dense1`
' being created, one for each branch.

resnet_incorrect = trace \x.
  dense1 = callable $ dense (Fin 2) (Fin 10)
  x = dense1 x
  x = trace_map (map relu) x
  -- This value of `x` will be consumed twice.

  dense3 = callable $ dense _ (Fin 25)
  y = dense3 x
  y = trace_map (map relu) y
  dense5 = callable $ dense _ (Fin 10)
  y = dense5 y
  y = trace_map2 add y x

  y = trace_map (map relu) y
  dense7 = callable $ dense _ (Fin 2)
  y = dense7 y
  y

w_resnet_bad = init resnet_incorrect (newKey 1)
:t w_resnet_bad
:p forward resnet_incorrect w_resnet_bad (for _. 0.)

' ResNet - nested sequential

' In `resnet_block` below, `x` is used twice, but that is safe because it's the
' input tracer with no network parameters yet.

' We'd like the return type to simply be `_`.
' Unfortunately that currently leads to a "leaked type variable `a`" error.

def resnet_block (a:Type)
      : Layer (a=>Float) (a=>Float) (((DenseParams _ a) & (DenseParams a _) & Unit) & Unit)
      = trace \x.
  dense1 = callable $ dense a (Fin 25)
  y = dense1 x
  y = trace_map (map relu) y
  dense3 = callable $ dense _ a
  y = dense3 y
  y = trace_map2 add y x
  y

resnet = trace \x.
  dense1 = callable $ dense (Fin 2) (Fin 10)
  x = dense1 x
  x = trace_map (map relu) x
  block3 = callable $ resnet_block _
  x = block3 x
  x = trace_map (map relu) x
  dense7 = callable $ dense _ (Fin 2)
  x = dense7 x
  x

w_resnet = init resnet (newKey 1)
:t w_resnet
:p forward resnet w_resnet (for _. 0.)

' ## Training

' Train a multiclass classifier with minibatch SGD,
' where `minibatch * minibatches = batch`.

def split (x: batch=>v) : minibatches=>minibatch=>v =
  for b j. x.((ordinal (b, j))@batch)
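
' For example (illustrative only), splitting six elements into two
minibatches of three:

nums : (Fin 6)=>Int = for i. ordinal i
nums_split : (Fin 2)=>(Fin 3)=>Int = split nums
:p nums_split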

def trainClass [VSpace p] (model: Layer a (b=>Float) p)
               (x: batch=>a)
               (y: batch=>b)
               (epochs : Type)
               (minibatch : Type)
               (minibatches : Type) :
               (epochs => p & epochs => Float) =
  xs : minibatches => minibatch => a = split x
  ys : minibatches => minibatch => b = split y
  unzip $ withState (init model $ newKey 1) $ \params.
    for _ : epochs.
      loss = sum $ for b : minibatches.
        -- Differentiate the minibatch cross-entropy loss w.r.t. the parameters.
        (loss, gradfn) = vjp (\ps.
          -sum for j.
            logits = forward model ps xs.b.j
            (logsoftmax logits).(ys.b.j)) (get params)
        gparams = gradfn 1.0
        -- SGD step; note the learning-rate scale uses a hard-coded 100 here.
        params := (get params) - scaleVec (0.05 / (IToF 100)) gparams
        loss
      (get params, loss)

' Sextant classification dataset. The label is the sign of
' `x1^3 - 3 * x1 * x2^2`, which partitions the plane into six sectors.

[k1, k2] = splitKey $ newKey 1
x1 : Fin 400 => Float = arb k1
x2 : Fin 400 => Float = arb k2
y = for i. case ((x1.i * x1.i * x1.i - 3. * x1.i * x2.i * x2.i) > 0.) of
  True -> 1
  False -> 0
xs = for i. [x1.i, x2.i]

import plot
:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i

' Train classifier on this dataset.

-- model = mlp
model = resnet

(all_params, losses) = trainClass model xs (for i. (y.i @ (Fin 2))) (Fin 3000) (Fin 50) (Fin 8)

' Classification landscape, evolving over training time. Colour denotes the
' softmax output.

span = linspace (Fin 15) (-1.0) (1.0)
tests = for h : (Fin 100). for i. for j.
  r = softmax $ forward model all_params.((ordinal h * 30)@_) [span.j, -span.i]
  [r.(1@_), 0.5*r.(1@_), r.(0@_)]

:html imseqshow tests
Review comment: Might make sense to assert that `minibatches * minibatch = batch`, because this function will happily throw away data.
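
A minimal sketch of such a check (illustrative only; it assumes `size` on index-set types and only reports the mismatch rather than aborting):

def sizesMatch (batch:Type) (minibatches:Type) (minibatch:Type) : Bool =
  size minibatches * size minibatch == size batch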