Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit d5777de

Browse files
committed
Parameterized types on DeepONet and adjoint to concrete type
1 parent 1ec89bf commit d5777de

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/DeepONet.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ branch net: (Chain(Dense(2, 128), Dense(128, 64), Dense(64, 72)))
7777
Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
7878
```
7979
"""
80-
struct DeepONet
81-
branch_net::Flux.Chain
82-
trunk_net::Flux.Chain
80+
struct DeepONet{T1, T2}
81+
branch_net::T1
82+
trunk_net::T2
8383
end
8484

8585
# Declare the function that assigns Weights and biases to the layer
@@ -99,7 +99,7 @@ function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
9999
trunk_net = construct_subnet(architecture_trunk, act_trunk;
100100
init=init_trunk, bias=bias_trunk)
101101

102-
return DeepONet(branch_net, trunk_net)
102+
return DeepONet{typeof(branch_net),typeof(trunk_net)}(branch_net, trunk_net)
103103
end
104104

105105
Flux.@functor DeepONet
@@ -116,7 +116,7 @@ function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
116116
However, we perform the transformations by the NNs always in the first dim
117117
so we need to adjust (i.e. transpose) one of the inputs,
118118
which we do on the branch input here =#
119-
return branch(x)' * trunk(y)
119+
return Array(branch(x)') * trunk(y)
120120
end
121121

122122
# Sensors stay the same and shouldn't be batched

0 commit comments

Comments
 (0)