Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 38d9b7c

Browse files
authored
Merge pull request #45 from Abhishek-1Bhatt/type-inference
Improving Type Inference on DeepONet
2 parents b900ba9 + 0a0184a commit 38d9b7c

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/DeepONet.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ branch net: (Chain(Dense(2, 128), Dense(128, 64), Dense(64, 72)))
7777
Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
7878
```
7979
"""
80-
struct DeepONet
81-
branch_net::Flux.Chain
82-
trunk_net::Flux.Chain
80+
struct DeepONet{T1, T2}
81+
branch_net::T1
82+
trunk_net::T2
8383
end
8484

8585
# Declare the function that assigns Weights and biases to the layer
@@ -99,7 +99,7 @@ function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
9999
trunk_net = construct_subnet(architecture_trunk, act_trunk;
100100
init=init_trunk, bias=bias_trunk)
101101

102-
return DeepONet(branch_net, trunk_net)
102+
return DeepONet{typeof(branch_net),typeof(trunk_net)}(branch_net, trunk_net)
103103
end
104104

105105
Flux.@functor DeepONet
@@ -116,7 +116,7 @@ function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
116116
However, we perform the transformations by the NNs always in the first dim
117117
so we need to adjust (i.e. transpose) one of the inputs,
118118
which we do on the branch input here =#
119-
return branch(x)' * trunk(y)
119+
return Array(branch(x)') * trunk(y)
120120
end
121121

122122
# Sensors stay the same and shouldn't be batched

0 commit comments

Comments
 (0)