@@ -21,27 +21,31 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
-
-# output
-
-Branch net :
-(
-    Chain(
+julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
+@compact(
+    branch = Chain(
         layer_1 = Dense(64 => 32),      # 2_080 parameters
         layer_2 = Dense(32 => 32),      # 1_056 parameters
         layer_3 = Dense(32 => 16),      # 528 parameters
     ),
-)
-
-Trunk net :
-(
-    Chain(
+    trunk = Chain(
         layer_1 = Dense(1 => 8),        # 16 parameters
         layer_2 = Dense(8 => 8),        # 72 parameters
         layer_3 = Dense(8 => 16),       # 144 parameters
     ),
-)
+) do (u, y)
+    t = trunk(y)
+    b = branch(u)
+    @argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
+    @argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."
+    b_ = if ndims(t) == ndims(b)
+        b
+    else
+        reshape(b, size(b, 1), 1, size(b)[2:end]...)
+    end
+    return dropdims(sum(t .* b_; dims = 1); dims = 1)
+end       # Total: 3_896 parameters,
+          #        plus 0 states.
 ```
 """
 function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
@@ -81,54 +85,48 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
-trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
-don_ = DeepONet(branch_net, trunk_net)
+julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
 
-# output
+julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
 
-Branch net :
-(
-    Chain(
+julia> deeponet = DeepONet(branch_net, trunk_net)
+@compact(
+    branch = Chain(
         layer_1 = Dense(64 => 32),      # 2_080 parameters
         layer_2 = Dense(32 => 32),      # 1_056 parameters
         layer_3 = Dense(32 => 16),      # 528 parameters
     ),
-)
-
-Trunk net :
-(
-    Chain(
+    trunk = Chain(
         layer_1 = Dense(1 => 8),        # 16 parameters
         layer_2 = Dense(8 => 8),        # 72 parameters
         layer_3 = Dense(8 => 16),       # 144 parameters
     ),
-)
+) do (u, y)
+    t = trunk(y)
+    b = branch(u)
+    @argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
+    @argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."
+    b_ = if ndims(t) == ndims(b)
+        b
+    else
+        reshape(b, size(b, 1), 1, size(b)[2:end]...)
+    end
+    return dropdims(sum(t .* b_; dims = 1); dims = 1)
+end       # Total: 3_896 parameters,
+          #        plus 0 states.
 ```
 """
 function DeepONet(branch::L1, trunk::L2) where {L1, L2}
-    return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y) # ::AbstractArray{<:Real, M} where {M}
-        t = trunk(y)   # p x N x nb
-        b = branch(u)  # p x nb
-
-        # checks for last dimension size
-        @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same amount \
-                                          of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \
-                                          won't work."
-
-        tᵀ = permutedims(t, (2, 1, 3))                          # N x p x nb
-        b_ = permutedims(reshape(b, size(b)..., 1), (1, 3, 2))  # p x 1 x nb
-        G = batched_mul(tᵀ, b_)                                 # N x 1 x nb
-        @return dropdims(G; dims=2)
-    end
-end
+    return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y)
+        t = trunk(y)   # p x N x nb...
+        b = branch(u)  # p x nb...
 
-function Base.show(io::IO, model::Lux.CompactLuxLayer{:DeepONet})
-    Lux._print_wrapper_model(io, "Branch net :\n", model.layers.branch)
-    print(io, "\n\n")
-    Lux._print_wrapper_model(io, "Trunk net :\n", model.layers.trunk)
-end
+        @argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
+        @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
+                                          amount of nodes in the last layer. Otherwise \
+                                          Σᵢ bᵢⱼ tᵢₖ won't work."
 
-function Base.show(io::IO, ::MIME"text/plain", x::CompactLuxLayer{:DeepONet})
-    show(io, x)
+        b_ = ndims(t) == ndims(b) ? b : reshape(b, size(b, 1), 1, size(b)[2:end]...)
+        @return dropdims(sum(t .* b_; dims=1); dims=1)
+    end
 end
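
The diff above replaces the `batched_mul`-based contraction with a broadcasted multiply-and-sum, which also handles the case where the trunk output has no extra query dimension. A small standalone check (not part of the diff; shapes are illustrative) that the two formulations agree in the original 3-D case:

```julia
using NNlib: batched_mul

p, N, nb = 16, 10, 5
t = rand(Float32, p, N, nb)   # trunk output: p × N × nb
b = rand(Float32, p, nb)      # branch output: p × nb

# old path: tᵀ (N × p × nb) ⊠ b_ (p × 1 × nb) → N × 1 × nb, then drop dim 2
tᵀ = permutedims(t, (2, 1, 3))
b_old = permutedims(reshape(b, size(b)..., 1), (1, 3, 2))
G_old = dropdims(batched_mul(tᵀ, b_old); dims=2)

# new path: broadcast b over a singleton query dimension, then sum out p
b_new = reshape(b, p, 1, nb)
G_new = dropdims(sum(t .* b_new; dims=1); dims=1)

G_old ≈ G_new  # true: both compute Gⱼₖ = Σᵢ tᵢⱼₖ bᵢₖ
```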
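For context, a minimal forward-pass sketch with the new layer (a hedged example, not from the diff: it assumes Lux plus the package defining `DeepONet` are loaded, and the shapes are illustrative):

```julia
using Lux, Random  # plus the package defining DeepONet above

deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
ps, st = Lux.setup(Random.default_rng(), deeponet)

u = rand(Float32, 64, 5)     # 5 sampled input functions, 64 sensor points each
y = rand(Float32, 1, 10, 5)  # 10 one-dimensional query locations per sample

# branch(u) is 16 × 5 and trunk(y) is 16 × 10 × 5; contracting over the
# 16 shared latent nodes yields one prediction per query location and sample
G, _ = deeponet((u, y), ps, st)
size(G)  # (10, 5)
```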