564
564
565
565
## Returns
566
566
567
- `PairwiseFusion` returns a tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
567
+ A tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
568
568
"""
569
569
struct PairwiseFusion{F, T <: NamedTuple }
570
570
connection:: F
@@ -576,12 +576,17 @@ function PairwiseFusion(connection, layers...)
576
576
return PairwiseFusion (connection, NamedTuple {names} (layers))
577
577
end
578
578
579
- function (m:: PairwiseFusion )(x:: T ) where {T}
580
- lx = length (x)
581
- N = length (m. layers)
579
+ function _pairwise_check (lx, N, T)
582
580
if T <: Tuple && lx != N
583
581
throw (ArgumentError (" PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs" ))
584
582
end
583
+ end
584
+ ChainRulesCore. @non_differentiable _pairwise_check (lx, N, T)
585
+
586
+ function (m:: PairwiseFusion )(x:: T ) where {T}
587
+ lx = length (x)
588
+ N = length (m. layers)
589
+ _pairwise_check (lx, N, T)
585
590
applypairwisefusion (m. layers, m. connection, x)
586
591
end
587
592
@@ -590,10 +595,12 @@ end
590
595
y_symbols = [gensym () for _ in 1 : (N + 1 )]
591
596
getinput (i) = T <: Tuple ? :(x[$ i]) : :x
592
597
calls = [:($ (y_symbols[N + 1 ]) = $ (getinput (1 )))]
593
- append! (calls,
594
- [:($ (y_symbols[i]) = layers[$ i]($ (y_symbols[N + 1 ]));
595
- $ (y_symbols[N + 1 ]) = connection ($ (y_symbols[i]), $ (getinput (i + 1 ))))
596
- for i in 1 : N - 1 ])
598
+ for i in 1 : N - 1
599
+ push! (calls, quote
600
+ $ (y_symbols[i]) = layers[$ i]($ (y_symbols[N + 1 ]))
601
+ $ (y_symbols[N + 1 ]) = connection ($ (y_symbols[i]), $ (getinput (i + 1 )))
602
+ end )
603
+ end
597
604
push! (calls, :($ (y_symbols[N]) = layers[$ N]($ (y_symbols[N + 1 ]))))
598
605
push! (calls, :(return tuple ($ (Tuple (y_symbols[1 : N])... ))))
599
606
return Expr (:block , calls... )
0 commit comments