diff --git a/examples/nn_layers.dx b/examples/nn_layers.dx
new file mode 100644
index 000000000..a4a1c75e4
--- /dev/null
+++ b/examples/nn_layers.dx
@@ -0,0 +1,277 @@
+' # Neural Networks
+
+' ## NN Prelude
+
+' This ReLU implementation is for one neuron.
+Use `map relu` to apply it to a 1D table, `map (map relu)` for 2D, etc.
+
+def relu (input : Float) : Float =
+  select (input > 0.0) input 0.0
+
+' A pair of vector spaces is also a vector space.
+
+instance [Add a, Add b] Add (a & b)
+  add = \(a, b) (c, d). ((a + c), (b + d))
+  sub = \(a, b) (c, d). ((a - c), (b - d))
+  zero = (zero, zero)
+
+instance [VSpace a, VSpace b] VSpace (a & b)
+  scaleVec = \s (a, b). (scaleVec s a, scaleVec s b)
+
+' Layer type, describing a function with trainable side-parameters.
+This may represent a primitive layer (e.g. dense), a composition of layers
+(e.g. resnet_block), or an entire network.
+`forward` is the function implementing the layer computation.
+`init` provides initial values of the parameters, given a random key.
+
+data Layer inp:Type out:Type params:Type =
+  AsLayer {forward:(params -> inp -> out) & init:(Key -> params)}
+
+' Convenience functions to extract the `forward` and `init` functions.
+
+def forward (l:Layer i o p) (ps : p) (x : i) : o =
+  (AsLayer l') = l
+  (getAt #forward l') ps x
+
+def init (l:Layer i o p) (k:Key) : p =
+  (AsLayer l') = l
+  (getAt #init l') k
+
+' Adapt a pure function into a (parameterless) `Layer`.
+This is unused and for illustration only: we will see below how to apply
+pure functions directly with `trace_map`.
+
+def as_layer (f:u->v) : Layer u v Unit =
+  AsLayer {
+    forward = \() x. f x,
+    init = \_. ()
+  }
+
+' ## Layers
+
+' Dense layer.
+
+def DenseParams (a:Type) (b:Type) : Type =
+  ((a=>b=>Float) & (b=>Float))
+
+def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) (DenseParams a b) =
+  AsLayer {
+    forward = (\(weight, bias) x.
+                 for j. (bias.j + sum for i. weight.i.j * x.i)),
+    init = arb
+  }
+
+' CNN layer.
+
+def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
+  ((outc=>inc=>Fin kh=>Fin kw=>Float) &
+   (outc=>Float))
+
+def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
+           (kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float) :
+           outc=>(Fin h)=>(Fin w)=>Float =
+  for o i j.
+    (i', j') = (ordinal i, ordinal j)
+    -- Only positions where the whole kernel fits produce a value; the rest are zero.
+    case (i' + kh) <= h && (j' + kw) <= w of
+      True ->
+        sum for (ki, kj, inp).
+          (di, dj) = (fromOrdinal (Fin h) (i' + (ordinal ki)),
+                      fromOrdinal (Fin w) (j' + (ordinal kj)))
+          x.inp.di.dj * kernel.o.inp.ki.kj
+      False -> zero
+
+def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int) :
+    Layer (inc=>(Fin h)=>(Fin w)=>Float)
+          (outc=>(Fin h)=>(Fin w)=>Float)
+          (CNNParams inc outc kw kh) =
+  AsLayer {
+    forward = (\(weight, bias) x. for o i j. (conv2d x weight).o.i.j + bias.o),
+    init = arb
+  }
+
+' ## Tracing
+
+' A tracer is an object that we pass through a network function. The tracer
+is used as if it were an actual input to the network, and invoking constituent
+layers results in new tracers denoting intermediate values of the network.
+A tracer is implemented as a partial neural network, from the original inputs
+to the current intermediate values.
+Starting with an identity layer as the input tracer, tracing thus accumulates
+an output tracer that implements the full network.
+
+def trace (f : Layer inp inp Unit -> Layer inp out params) : Layer inp out params =
+  tracer = AsLayer {
+    forward = \() x. x,
+    init = \key. ()
+  }
+  f tracer
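+
+' As a quick sanity check (an illustrative addition, with the hypothetical
+name `id_layer`; not part of the original example), tracing the identity
+function gives back a parameterless identity layer:
+
+id_layer : Layer Float Float Unit = trace \x. x
+:p forward id_layer (init id_layer (newKey 0)) 3.0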
+
+' Converts a pure function (e.g. `map relu`) to a callable acting on tracers.
+
+def trace_map (f : inp -> out) (tracer : Layer ph inp ps) : Layer ph out ps =
+  AsLayer {
+    forward = \w x. f (forward tracer w x),
+    init = \key. init tracer key
+  }
+
+' Converts a two-arg pure function (e.g. `add`) to a callable acting on tracers.
+
+def trace_map2 (f : inp0 -> inp1 -> out)
+               (tracer0 : Layer ph inp0 ps0)
+               (tracer1 : Layer ph inp1 ps1) : Layer ph out (ps0&ps1) =
+  AsLayer {
+    forward = \(w0, w1) x. f (forward tracer0 w0 x) (forward tracer1 w1 x),
+    init = (\key.
+      [k0, k1] = splitKey key
+      (init tracer0 k0, init tracer1 k1))
+  }
+
+' Adapts an existing `Layer` into a callable that can be invoked within a
+larger layer being traced.
+
+def callable (layer : Layer inp out params) (tracer : Layer ph inp ps) : Layer ph out (params & ps) =
+  AsLayer {
+    forward = \w x. forward layer (fst w) $ forward tracer (snd w) x,
+    init = (\key.
+      [k0, k1] = splitKey key
+      (init layer k0, init tracer k1))
+  }
+
+' ## Networks
+
+' MLP
+
+mlp = trace \x.
+  dense1 = callable $ dense (Fin 2) (Fin 25)
+  x = dense1 x
+  x = trace_map (map relu) x
+  dense3 = callable $ dense _ (Fin 2)
+  x = dense3 x
+  x
+
+w_mlp = init mlp (newKey 1)
+:t w_mlp
+:p forward mlp w_mlp (for _. 0.)
+
+' ResNet - incorrect first attempt
+
+' The following does not work correctly, because the (last assigned) value of
+`x` is consumed twice, directly or indirectly, by both inputs to `trace_map2`.
+Because `x` is a tracer, this creates two independent copies of the `dense1`
+parameters, one for each branch.
+
+resnet_incorrect = trace \x.
+  dense1 = callable $ dense (Fin 2) (Fin 10)
+  x = dense1 x
+  x = trace_map (map relu) x
+  -- This value of `x` will be consumed twice.
+
+  dense3 = callable $ dense _ (Fin 25)
+  y = dense3 x
+  y = trace_map (map relu) y
+  dense5 = callable $ dense _ (Fin 10)
+  y = dense5 y
+  y = trace_map2 add y x
+
+  y = trace_map (map relu) y
+  dense7 = callable $ dense _ (Fin 2)
+  y = dense7 y
+  y
+
+w_resnet_bad = init resnet_incorrect (newKey 1)
+:t w_resnet_bad
+:p forward resnet_incorrect w_resnet_bad (for _. 0.)
+
+' ResNet - nested sequential
+
+' In `resnet_block` below, `x` is used twice, but that is safe because it is
+the input tracer, which carries no network parameters yet.
+
+' We would like to simply write `_` for the return type here.
+Unfortunately that currently leads to a "leaked type variable `a`" error.
+
+def resnet_block (a:Type)
+    : Layer (a=>Float) (a=>Float) (((DenseParams _ a) & (DenseParams a _) & Unit) & Unit)
+    = trace \x.
+  dense1 = callable $ dense a (Fin 25)
+  y = dense1 x
+  y = trace_map (map relu) y
+  dense3 = callable $ dense _ a
+  y = dense3 y
+  y = trace_map2 add y x
+  y
+
+resnet = trace \x.
+  dense1 = callable $ dense (Fin 2) (Fin 10)
+  x = dense1 x
+  x = trace_map (map relu) x
+  block3 = callable $ resnet_block _
+  x = block3 x
+  x = trace_map (map relu) x
+  dense7 = callable $ dense _ (Fin 2)
+  x = dense7 x
+  x
+
+w_resnet = init resnet (newKey 1)
+:t w_resnet
+:p forward resnet w_resnet (for _. 0.)
+
+' ## Training
+
+' Train a multiclass classifier with minibatch SGD.
+Here `minibatch * minibatches = batch`.
+
+def split (x: batch=>v) : minibatches=>minibatch=>v =
+  for b j. x.((ordinal (b,j))@batch)
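+
+' For example (a small demonstration added here, with the hypothetical name
+`split_demo`; not in the original file), six elements split into three
+minibatches of two:
+
+split_demo : (Fin 3)=>(Fin 2)=>Int = split for i:(Fin 6). ordinal i
+:p split_demo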
+
+def trainClass [VSpace p] (model: Layer a (b=>Float) p)
+               (x: batch=>a)
+               (y: batch=>b)
+               (epochs : Type)
+               (minibatch : Type)
+               (minibatches : Type) :
+               (epochs => p & epochs => Float) =
+  xs : minibatches => minibatch => a = split x
+  ys : minibatches => minibatch => b = split y
+  unzip $ withState (init model $ newKey 1) $ \params.
+    for _ : epochs.
+      loss = sum $ for b : minibatches.
+        (loss, gradfn) = vjp (\params.
+          -sum for j.
+            logits = forward model params xs.b.j
+            (logsoftmax logits).(ys.b.j)) (get params)
+        gparams = gradfn 1.0
+        -- Gradient step with a fixed learning rate.
+        params := (get params) - scaleVec (0.05 / (IToF 100)) gparams
+        loss
+      (get params, loss)
+
+' Sextant classification dataset.
+
+[k1, k2] = splitKey $ newKey 1
+x1 : Fin 400 => Float = arb k1
+x2 : Fin 400 => Float = arb k2
+y = for i. case ((x1.i * x1.i * x1.i - 3. * x1.i * x2.i * x2.i) > 0.) of
+  True -> 1
+  False -> 0
+xs = for i. [x1.i, x2.i]
+
+import plot
+:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i
+
+' Train classifier on this dataset.
+
+-- model = mlp
+model = resnet
+
+(all_params, losses) = trainClass model xs (for i. (y.i @ (Fin 2))) (Fin 3000) (Fin 50) (Fin 8)
+
+' Classification landscape, evolving over training time. Colour denotes the
+softmax output.
+
+span = linspace (Fin 15) (-1.0) (1.0)
+tests = for h : (Fin 100). for i. for j.
+  r = softmax $ forward model all_params.((ordinal h * 30)@_) [span.j, -span.i]
+  [r.(1@_), 0.5*r.(1@_), r.(0@_)]
+
+:html imseqshow tests
diff --git a/examples/nn_layers_existential.dx b/examples/nn_layers_existential.dx
new file mode 100644
index 000000000..e22910cef
--- /dev/null
+++ b/examples/nn_layers_existential.dx
@@ -0,0 +1,298 @@
+' # Neural Networks
+
+' ## NN Prelude
+
+' This ReLU implementation is for one neuron.
+Use `map relu` to apply it to a 1D table, `map (map relu)` for 2D, etc.
+
+def relu (input : Float) : Float =
+  select (input > 0.0) input 0.0
+
+' A pair of vector spaces is also a vector space.
+
+instance [Add a, Add b] Add (a & b)
+  add = \(a, b) (c, d). ((a + c), (b + d))
+  sub = \(a, b) (c, d). ((a - c), (b - d))
+  zero = (zero, zero)
+
+instance [VSpace a, VSpace b] VSpace (a & b)
+  scaleVec = \s (a, b). (scaleVec s a, scaleVec s b)
+
+' Layer type, describing a function with trainable side-parameters.
+This may represent a primitive layer (e.g. dense), a composition of layers
+(e.g. resnet_block), or an entire network.
+`forward` is the function implementing the layer computation.
+`init` provides initial values of the parameters, given a random key.
+This version is existential over the `params` type, so that the actual
+structure of `params` can be hidden.
+
+data Layer inp:Type out:Type =
+  AsLayer params:Type vs:(VSpace params) {forward:(params -> inp -> out) & init:(Key -> params)}
+
+' Wrapper that infers the `vs` field: the `?=>` arrow directs it to be
+resolved from the declared `VSpace` instances.
+
+def make_layer (params:Type) (vs:(VSpace params)) ?=>
+               (methods:{forward:(params -> inp -> out) & init:(Key -> params)}) : Layer inp out =
+  AsLayer params vs methods
+
+' Adapt a pure function into a (parameterless) `Layer`.
+This is unused and for illustration only: we will see below how to apply
+pure functions directly with `trace_map`.
+
+def as_layer (f:u->v) : Layer u v =
+  make_layer Unit {
+    forward = \() x. f x,
+    init = \_. ()
+  }
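+
+' As one more illustration (a hypothetical layer named `gain_layer`, added
+here as a sketch; not part of the original file), a learnable scalar gain can
+be packaged the same way, with `Float` as the hidden parameter type:
+
+gain_layer : Layer Float Float =
+  make_layer Float {
+    forward = \gain x. gain * x,
+    init = arb
+  }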
+
+' ## Layers
+
+' Dense layer.
+
+def DenseParams (a:Type) (b:Type) : Type =
+  ((a=>b=>Float) & (b=>Float))
+
+def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) =
+  make_layer (DenseParams a b) {
+    forward = (\(weight, bias) x.
+                 for j. (bias.j + sum for i. weight.i.j * x.i)),
+    init = arb
+  }
+
+' CNN layer.
+
+def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
+  ((outc=>inc=>Fin kh=>Fin kw=>Float) &
+   (outc=>Float))
+
+def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
+           (kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float) :
+           outc=>(Fin h)=>(Fin w)=>Float =
+  for o i j.
+    (i', j') = (ordinal i, ordinal j)
+    case (i' + kh) <= h && (j' + kw) <= w of
+      True ->
+        sum for (ki, kj, inp).
+          (di, dj) = (fromOrdinal (Fin h) (i' + (ordinal ki)),
+                      fromOrdinal (Fin w) (j' + (ordinal kj)))
+          x.inp.di.dj * kernel.o.inp.ki.kj
+      False -> zero
+
+def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int) :
+    Layer (inc=>(Fin h)=>(Fin w)=>Float)
+          (outc=>(Fin h)=>(Fin w)=>Float) =
+  make_layer (CNNParams inc outc kw kh) {
+    forward = (\(weight, bias) x. for o i j. (conv2d x weight).o.i.j + bias.o),
+    init = arb
+  }
+
+' ## Tracing
+
+' A tracer is an object that we pass through a network function. The tracer
+is used as if it were an actual input to the network, and invoking constituent
+layers results in new tracers denoting intermediate values of the network.
+A tracer is implemented as a partial neural network, from the original inputs
+to the current intermediate values.
+Starting with an identity layer as the input tracer, tracing thus accumulates
+an output tracer that implements the full network.
+
+def trace (f : Layer inp inp -> Layer inp out) : Layer inp out =
+  tracer = make_layer Unit {
+    forward = \() x. x,
+    init = \key. ()
+  }
+  f tracer
+
+' Converts a pure function (e.g. `map relu`) to a callable acting on tracers.
+
+def trace_map (f : inp -> out) (tracer : Layer ph inp) : Layer ph out =
+  (AsLayer ps vs tracer') = tracer
+  %instance vsHint = vs
+  make_layer ps {
+    forward = \w x. f ((getAt #forward tracer') w x),
+    init = \key. (getAt #init tracer') key
+  }
+
+' Converts a two-arg pure function (e.g. `add`) to a callable acting on tracers.
+
+def trace_map2 (f : inp0 -> inp1 -> out)
+               (tracer0 : Layer ph inp0)
+               (tracer1 : Layer ph inp1) : Layer ph out =
+  (AsLayer ps0 vs0 tracer0') = tracer0
+  (AsLayer ps1 vs1 tracer1') = tracer1
+  %instance vs0Hint = vs0
+  %instance vs1Hint = vs1
+  make_layer (ps0&ps1) {
+    forward = \(w0, w1) x. f ((getAt #forward tracer0') w0 x) ((getAt #forward tracer1') w1 x),
+    init = (\key.
+      [k0, k1] = splitKey key
+      ((getAt #init tracer0') k0, (getAt #init tracer1') k1))
+  }
+
+' Adapts an existing `Layer` into a callable that can be invoked within a
+larger layer being traced.
+
+def callable (layer : Layer inp out) (tracer : Layer ph inp) : Layer ph out =
+  (AsLayer params vsl layer') = layer
+  (AsLayer ps vs tracer') = tracer
+  %instance vslHint = vsl
+  %instance vsHint = vs
+  make_layer (params & ps) {
+    forward = \w x. (getAt #forward layer') (fst w) $ (getAt #forward tracer') (snd w) x,
+    init = (\key.
+      [k0, k1] = splitKey key
+      ((getAt #init layer') k0, (getAt #init tracer') k1))
+  }
+
+' ## Networks
+
+' MLP
+
+mlp = trace \x.
+  dense1 = callable $ dense (Fin 2) (Fin 25)
+  x = dense1 x
+  x = trace_map (map relu) x
+  dense3 = callable $ dense _ (Fin 2)
+  x = dense3 x
+  x
+
+(AsLayer _ _ mlp') = mlp
+w_mlp = (getAt #init mlp') (newKey 1)
+:t w_mlp
+:p (getAt #forward mlp') w_mlp (for _. 0.)
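+
+' A small convenience (a hypothetical helper named `run_layer`; not in the
+original file): unpack a layer, initialise its parameters and apply it in one
+step.
+
+def run_layer (l : Layer i o) (k : Key) (x : i) : o =
+  (AsLayer _ _ l') = l
+  (getAt #forward l') ((getAt #init l') k) x
+
+:p run_layer mlp (newKey 1) (for _. 0.)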
+
+' ResNet - incorrect first attempt
+
+' The following does not work correctly, because the (last assigned) value of
+`x` is consumed twice, directly or indirectly, by both inputs to `trace_map2`.
+Because `x` is a tracer, this creates two independent copies of the `dense1`
+parameters, one for each branch.
+
+resnet_incorrect = trace \x.
+  dense1 = callable $ dense (Fin 2) (Fin 10)
+  x = dense1 x
+  x = trace_map (map relu) x
+  -- This value of `x` will be consumed twice.
+
+  dense3 = callable $ dense _ (Fin 25)
+  y = dense3 x
+  y = trace_map (map relu) y
+  dense5 = callable $ dense _ (Fin 10)
+  y = dense5 y
+  y = trace_map2 add y x
+
+  y = trace_map (map relu) y
+  dense7 = callable $ dense _ (Fin 2)
+  y = dense7 y
+  y
+
+(AsLayer _ _ resnet_incorrect') = resnet_incorrect
+w_resnet_bad = (getAt #init resnet_incorrect') (newKey 1)
+:t w_resnet_bad
+:p (getAt #forward resnet_incorrect') w_resnet_bad (for _. 0.)
+
+' ResNet - nested sequential
+
+' In `resnet_block` below, `x` is used twice, but that is safe because it is
+the input tracer, which carries no network parameters yet.
+
+' We would like to simply write `_` for the return type here.
+Unfortunately that currently leads to a "leaked type variable `a`" error.
+
+def resnet_block (a:Type) : Layer (a=>Float) (a=>Float) = trace \x.
+  dense1 = callable $ dense a (Fin 25)
+  y = dense1 x
+  y = trace_map (map relu) y
+  dense3 = callable $ dense _ a
+  y = dense3 y
+  y = trace_map2 add y x
+  y
+
+resnet = trace \x.
+  dense1 = callable $ dense (Fin 2) (Fin 10)
+  x = dense1 x
+  x = trace_map (map relu) x
+  block3 = callable $ resnet_block _
+  x = block3 x
+  x = trace_map (map relu) x
+  dense7 = callable $ dense _ (Fin 2)
+  x = dense7 x
+  x
+
+(AsLayer _ _ resnet') = resnet
+w_resnet = (getAt #init resnet') (newKey 1)
+:t w_resnet
+:p (getAt #forward resnet') w_resnet (for _. 0.)
+
+' ## Training
+
+' Train a multiclass classifier with minibatch SGD.
+Here `minibatch * minibatches = batch`.
+
+def split (x: batch=>v) : minibatches=>minibatch=>v =
+  for b j. x.((ordinal (b,j))@batch)
+
+def doTrain (ps:Type) (vs:(VSpace ps)) ?=>
+            (model': {forward:(ps->a->(b=>Float)) & init:(Key->ps)})
+            (x: batch=>a)
+            (y: batch=>b)
+            (epochs : Type)
+            (minibatch : Type)
+            (minibatches : Type) :
+            (epochs => (a -> (b=>Float)) & epochs => Float) =
+  xs : minibatches => minibatch => a = split x
+  ys : minibatches => minibatch => b = split y
+  unzip $ withState ((getAt #init model') $ newKey 1) $ \params.
+    for _ : epochs.
+      loss = sum $ for b : minibatches.
+        (loss, gradfn) = vjp (\params.
+          -sum for j.
+            logits = (getAt #forward model') params xs.b.j
+            (logsoftmax logits).(ys.b.j)) (get params)
+        gparams = gradfn 1.0
+        -- Gradient step with a fixed learning rate.
+        params := (get params) - scaleVec (0.05 / (IToF 100)) gparams
+        loss
+      ((\params x. (getAt #forward model') params x) (get params), loss)
+
+def trainClass (model: Layer a (b=>Float))
+               (x: batch=>a)
+               (y: batch=>b)
+               (epochs : Type)
+               (minibatch : Type)
+               (minibatches : Type) :
+               (epochs => (a -> (b=>Float)) & epochs => Float) =
+  (AsLayer ps vs model') = model
+  %instance vsHint = vs
+  doTrain ps model' x y epochs minibatch minibatches
+
+' Sextant classification dataset.
+
+[k1, k2] = splitKey $ newKey 1
+x1 : Fin 400 => Float = arb k1
+x2 : Fin 400 => Float = arb k2
+y = for i. case ((x1.i * x1.i * x1.i - 3. * x1.i * x2.i * x2.i) > 0.) of
+  True -> 1
+  False -> 0
+xs = for i. [x1.i, x2.i]
+
+import plot
+:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i
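+
+' The label is the sign of `x1^3 - 3*x1*x2^2 = Re((x1 + i*x2)^3)`, which
+alternates over the six sextants of the plane; hence the dataset's name.
+As a quick sanity check (an addition, not in the original file), the two
+classes should be roughly balanced:
+
+:p sum for i. IToF y.i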
+
+' Train classifier on this dataset.
+
+-- model = mlp
+model = resnet
+
+(all_param_models, losses) = trainClass model xs (for i. (y.i @ (Fin 2))) (Fin 3000) (Fin 50) (Fin 8)
+
+' Classification landscape, evolving over training time. Colour denotes the
+softmax output.
+
+span = linspace (Fin 15) (-1.0) (1.0)
+tests = for h : (Fin 100). for i. for j.
+  r = softmax $ all_param_models.((ordinal h * 30)@_) [span.j, -span.i]
+  [r.(1@_), 0.5*r.(1@_), r.(0@_)]
+
+:html imseqshow tests
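+
+' As a quick check of training progress (an addition, not in the original
+file), compare the loss at the first and last recorded epochs:
+
+:p losses.(0@_)
+:p losses.(2999@_)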