38
38
39
39
# Positional form: the layers passed as separate arguments are forwarded as one tuple.
function Chain(xs...)
    return Chain(xs)
end

# Keyword form: each keyword names a layer. The name `layers` is rejected, and
# calling with no keywords produces a Chain wrapping an empty tuple.
function Chain(; kw...)
    if :layers in keys(kw)
        throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
    end
    return isempty(kw) ? Chain(()) : Chain(values(kw))
end
67
67
68
68
# Indexing with an array of indices builds a new Chain from the selected layers.
function Base.getindex(c::Chain, i::AbstractArray)
    return Chain(c.layers[i])
end

# Named variant: the names and the layers are indexed with the same `i`, so the
# resulting Chain keeps the name attached to each selected layer.
function Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray)
    names = keys(c)[i]
    selected = Tuple(c.layers)[i]
    return Chain(NamedTuple{names}(selected))
end
71
71
function Base. show (io:: IO , c:: Chain )
72
72
print (io, " Chain(" )
73
73
_show_layers (io, c. layers)
487
487
# Positional form: the trailing sub-layers are forwarded as a single tuple.
function Parallel(connection, layers...)
    return Parallel(connection, layers)
end
488
488
function Parallel (connection; kw... )
489
489
layers = NamedTuple (kw)
490
- if :layers in Base . keys (layers) || :connection in Base . keys (layers)
490
+ if :layers in keys (layers) || :connection in keys (layers)
491
491
throw (ArgumentError (" a Parallel layer cannot have a named sub-layer called `connection` or `layers`" ))
492
492
end
493
493
isempty (layers) && return Parallel (connection, ())
@@ -498,28 +498,138 @@ end
498
498
499
499
# Single input: every sub-layer is applied to the same `x`, and the outputs are
# splatted into `connection`.
function (m::Parallel)(x)
    outputs = map(layer -> layer(x), Tuple(m.layers))
    return m.connection(outputs...)
end

# A tuple input is unpacked and re-dispatched as separate positional inputs.
function (m::Parallel)(xs::Tuple)
    return m(xs...)
end
501
- function (m:: Parallel )(xs... )
502
- nl = length (m. layers)
503
- nx = length (xs)
504
- if nl != nx
501
+
502
# Verify that the number of inputs equals the number of sub-layers.
# Returns `nothing` on success and throws an `ArgumentError` on a mismatch.
function _parallel_check(layers, xs)
    nl, nx = length(layers), length(xs)
    nl == nx && return nothing
    throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
end
509
+ ChainRulesCore. @non_differentiable _parallel_check (nl, nx)
510
+
511
# Several inputs: after checking that the counts agree, the i-th input is piped
# into the i-th sub-layer and all outputs are splatted into `connection`.
function (m::Parallel)(xs...)
    _parallel_check(m.layers, xs)
    outputs = map(|>, xs, Tuple(m.layers))
    return m.connection(outputs...)
end
509
515
510
516
# Any other index is forwarded to the stored layers and returns what that lookup gives.
function Base.getindex(m::Parallel, i)
    return m.layers[i]
end

# A vector of indices yields a new Parallel with the same connection.
function Base.getindex(m::Parallel, i::AbstractVector)
    return Parallel(m.connection, m.layers[i])
end

# Named variant: names and layers are indexed with the same `i` so the selected
# sub-layers keep their names.
function Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector)
    names = keys(m)[i]
    selected = Tuple(m.layers)[i]
    return Parallel(m.connection, NamedTuple{names}(selected))
end

# The keys of a Parallel are the keys of its `layers` field.
function Base.keys(m::Parallel)
    return keys(getfield(m, :layers))
end
516
522
517
523
# Compact display: `Parallel(<connection>, <layers...>)`; the layer list itself is
# written by `_show_layers`.
function Base.show(io::IO, m::Parallel)
    print(io, "Parallel(", m.connection, ", ")
    _show_layers(io, m.layers)
    return print(io, ")")
end
522
528
"""
    PairwiseFusion(connection, layers...)

Apply `layers` one after another, fusing the output of each layer with the next
input through `connection` before it is passed to the following layer, and return
the output of every layer.

## Arguments

- `connection`: A function taking 2 inputs and combining them into a single output.
  It is called as `connection(previous_output, next_input)`.
- `layers`: The layers whose outputs are combined

## Inputs

This layer behaves differently based on input type:

1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,
   then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.
   Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))`
   may be drawn as:
```
x1 → layer1 → y1 ↘
                  connection → layer2 → y2 ↘
              x2 ↗                          connection → layer3 → y3
                                        x3 ↗
```
   ... or written as:
```julia
y1 = layer1(x1)
y2 = layer2(connection(y1, x2))
y3 = layer3(connection(y2, x3))
```

2. With just one input, each layer receives the same `x` combined with the previous output.
   Thus `y = PairwiseFusion(connection, layers...)(x)` obeys:

```julia
y[1] == layers[1](x)
for i in 2:length(layers)
    y[i] == layers[i](connection(y[i-1], x))
end
```

A single input that is itself a tuple is always treated as case 1, so its length
must equal the number of layers; otherwise an `ArgumentError` is thrown.

## Returns

A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
"""
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
    connection::F
    layers::T
end
576
+
577
# Positional form: the trailing sub-layers are forwarded as a single tuple.
function PairwiseFusion(connection, layers...)
    return PairwiseFusion(connection, layers)
end

# Keyword form: each keyword names a sub-layer. The names `layers` and
# `connection` are rejected, and calling with no keywords stores an empty tuple.
function PairwiseFusion(connection; kw...)
    layers = NamedTuple(kw)
    names = keys(layers)
    if :layers in names || :connection in names
        throw(ArgumentError("a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`"))
    end
    return isempty(layers) ? PairwiseFusion(connection, ()) : PairwiseFusion(connection, layers)
end
586
+
587
# Validate the input of a PairwiseFusion call.
#
# `x` is the input, `layers` the stored sub-layers and `T` the type of `x`.
# When `T` is a tuple type, its length must equal the number of layers, otherwise
# an `ArgumentError` is thrown. Any other input is accepted as is. Returns `nothing`.
function _pairwise_check(x, layers, T)
    # A non-tuple input is shared by every layer, so its length is irrelevant.
    # Bail out before calling `length`, which such an input may not even define.
    T <: Tuple || return nothing
    lx = length(x)
    N = length(layers)
    if lx != N
        throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
    end
    return nothing
end
594
+ ChainRulesCore. @non_differentiable _pairwise_check (lx, N, T)
595
+
596
# Single-argument call: `x` is either a tuple holding one input per layer or one
# value shared by all layers. The input is validated first, then the fusion runs.
function (m::PairwiseFusion)(x::T) where {T}
    _pairwise_check(x, m.layers, T)
    return applypairwisefusion(m.layers, m.connection, x)
end

# Several positional inputs are bundled into a tuple and re-dispatched.
function (m::PairwiseFusion)(xs...)
    return m(xs)
end
601
+
602
# Unrolled implementation of the pairwise fusion for a tuple of `N` layers.
#
# The generated code keeps one running value (the "carry"). It starts as the first
# input; every layer is applied to the carry, and between two layers the carry
# becomes `connection(previous_output, next_input)`. When `x` is a tuple the i-th
# layer's input is `x[i]`, otherwise the same `x` is reused for every layer.
# Returns a tuple with the output of each layer, or `()` when there are no layers.
@generated function applypairwisefusion(layers::Tuple{Vararg{Any,N}}, connection, x::T) where {N, T}
    # With no layers there is nothing to unroll; the result is the empty tuple.
    N == 0 && return :(tuple())
    # One fresh symbol per layer output, plus a final one for the carry.
    y_symbols = [gensym() for _ in 1:(N + 1)]
    getinput(i) = T <: Tuple ? :(x[$i]) : :x
    carry = y_symbols[N + 1]
    calls = [:($carry = $(getinput(1)))]
    for i in 1:(N - 1)
        push!(calls, quote
            $(y_symbols[i]) = layers[$i]($carry)
            $carry = connection($(y_symbols[i]), $(getinput(i + 1)))
        end)
    end
    # The last layer has no successor, so its output is not fused any further.
    push!(calls, :($(y_symbols[N]) = layers[$N]($carry)))
    push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
    return Expr(:block, calls...)
end

# Named layers are processed in order as a plain tuple.
applypairwisefusion(layers::NamedTuple, connection, x) =
    applypairwisefusion(Tuple(layers), connection, x)
617
+
618
+ @functor PairwiseFusion
619
+
620
# Any other index is forwarded to the stored layers and returns what that lookup gives.
function Base.getindex(m::PairwiseFusion, i)
    return m.layers[i]
end

# A vector of indices yields a new PairwiseFusion with the same connection.
function Base.getindex(m::PairwiseFusion, i::AbstractVector)
    return PairwiseFusion(m.connection, m.layers[i])
end

# Named variant: names and layers are indexed with the same `i` so the selected
# sub-layers keep their names.
function Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector)
    names = keys(m)[i]
    selected = Tuple(m.layers)[i]
    return PairwiseFusion(m.connection, NamedTuple{names}(selected))
end

# The keys of a PairwiseFusion are the keys of its `layers` field.
function Base.keys(m::PairwiseFusion)
    return keys(getfield(m, :layers))
end
626
+
627
# Compact display: `PairwiseFusion(<connection>, <layers...>)`; the layer list
# itself is written by `_show_layers`.
function Base.show(io::IO, m::PairwiseFusion)
    print(io, "PairwiseFusion(", m.connection, ", ")
    _show_layers(io, m.layers)
    return print(io, ")")
end
632
+
523
633
"""
524
634
Embedding(in => out; init=randn32)
525
635
556
666
@functor Embedding
557
667
558
668
# Build an Embedding from an `in => out` pair; the weight is created by
# `init(out, in)`.
function Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32)
    return Embedding(init(out, in))
end

# An integer index selects column `x` of the weight.
function (m::Embedding)(x::Integer)
    return m.weight[:, x]
end

# A vector of indices is looked up through `NNlib.gather`.
function (m::Embedding)(x::AbstractVector)
    return NNlib.gather(m.weight, x)
end

# Other arrays are flattened, looked up, and reshaped so that the trailing
# dimensions of the result match `size(x)`.
function (m::Embedding)(x::AbstractArray)
    flat = m(vec(x))
    return reshape(flat, :, size(x)...)
end
@@ -565,7 +675,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
565
675
size (m. weight, 2 ) == L || throw (DimensionMismatch (" Matrix column must correspond with OneHot size: $(size (m. weight, 2 )) != $L " ))
566
676
return m (onecold (x))
567
677
end
568
-
678
+
569
679
# Compact display: `Embedding(<size(weight, 2)> => <size(weight, 1)>)`.
function Base.show(io::IO, m::Embedding)
    n_in = size(m.weight, 2)
    n_out = size(m.weight, 1)
    return print(io, "Embedding(", n_in, " => ", n_out, ")")
end
0 commit comments