487
487
Parallel (connection, layers... ) = Parallel (connection, layers)
488
488
function Parallel (connection; kw... )
489
489
layers = NamedTuple (kw)
490
- if :layers in Base . keys (layers) || :connection in Base . keys (layers)
490
+ if :layers in keys (layers) || :connection in keys (layers)
491
491
throw (ArgumentError (" a Parallel layer cannot have a named sub-layer called `connection` or `layers`" ))
492
492
end
493
493
isempty (layers) && return Parallel (connection, ())
@@ -510,16 +510,100 @@ end
510
510
Base. getindex (m:: Parallel , i) = m. layers[i]
511
511
Base. getindex (m:: Parallel , i:: AbstractVector ) = Parallel (m. connection, m. layers[i])
512
512
Base. getindex (m:: Parallel{<:Any, <:NamedTuple} , i:: AbstractVector ) =
513
- Parallel (m. connection, NamedTuple {Base. keys(m)[i]} (Tuple (m. layers)[i]))
513
+ Parallel (m. connection, NamedTuple {keys(m)[i]} (Tuple (m. layers)[i]))
514
514
515
- Base. keys (m:: Parallel ) = Base . keys (getfield (m, :layers ))
515
+ Base. keys (m:: Parallel ) = keys (getfield (m, :layers ))
516
516
517
517
function Base. show (io:: IO , m:: Parallel )
518
518
print (io, " Parallel(" , m. connection, " , " )
519
519
_show_layers (io, m. layers)
520
520
print (io, " )" )
521
521
end
522
522
523
+ """
524
+ PairwiseFusion(connection, layers...)
525
+
526
+ ```
527
+ x1 --> layer1 --> y1
528
+ |
529
+ |--> connection --> layer2 --> y2
530
+ | |
531
+ x2 |--> connection --> layer3 --> y3
532
+ | |
533
+ x3 |--> connection --> y4
534
+ |
535
+ x4
536
+ ```
537
+
538
+ ## Arguments
539
+
540
+ - `connection`: Takes 2 inputs and combines them
541
+ - `layers`: The layers whose outputs are combined
542
+
543
+ ## Inputs
544
+
545
+ This layer behaves differently based on input type:
546
+
547
+ 1. Input `x` is a tuple/vector of length `N`. Then `layers` must be a tuple of length `N`. The computation is as follows:
548
+
549
+ ```julia
550
+ y = x[1]
551
+ for i in 1:N
552
+ y = connection(x[i], layers[i](y))
553
+ end
554
+ ```
555
+
556
+ 2. Any other kind of input:
557
+
558
+ ```julia
559
+ y = x
560
+ for i in 1:N
561
+ y = connection(x, layers[i](y))
562
+ end
563
+ ```
564
+
565
+ ## Returns
566
+
567
+ `PairwiseFusion` returns a tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
568
+ """
569
+ struct PairwiseFusion{F, T<: Union{Tuple, NamedTuple} }
570
+ connection:: F
571
+ layers:: T
572
+ end
573
+
574
+ PairwiseFusion (connection, layers... ) = PairwiseFusion (connection, layers)
575
+ function PairwiseFusion (connection; kw... )
576
+ layers = NamedTuple (kw)
577
+ if :layers in keys (layers) || :connection in keys (layers)
578
+ throw (ArgumentError (" a Parallel layer cannot have a named sub-layer called `connection` or `layers`" ))
579
+ end
580
+ isempty (layers) && return Parallel (connection, ())
581
+ return PairwiseFusion (connection, layers)
582
+ end
583
+
584
+ function (m:: PairwiseFusion )(x:: T ) where {T}
585
+ getinput (i) = T <: Union{Tuple, Vector} ? x[i] : x
586
+ nx = length (x)
587
+ nlayers = length (m. layers)
588
+ if nx != nlayers
589
+ throw (ArgumentError (" PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs" ))
590
+ end
591
+ outputs = [m. layers[1 ](getinput (1 ))]
592
+ for i in 2 : nlayers
593
+ push! (outputs, m. layers[i](m. connection (getinput (i), outputs[i - 1 ])))
594
+ end
595
+ return outputs
596
+ end
597
+
598
+ @functor PairwiseFusion
599
+
600
+ Base. getindex (m:: PairwiseFusion , i) = m. layers[i]
601
+ Base. getindex (m:: PairwiseFusion , i:: AbstractVector ) = PairwiseFusion (m. connection, m. layers[i])
602
+ Base. getindex (m:: PairwiseFusion{<:Any, <:NamedTuple} , i:: AbstractVector ) =
603
+ PairwiseFusion (m. connection, NamedTuple {keys(m)[i]} (Tuple (m. layers)[i]))
604
+
605
+ Base. keys (m:: PairwiseFusion ) = keys (getfield (m, :layers ))
606
+
523
607
"""
524
608
Embedding(in => out; init=randn)
525
609
0 commit comments