Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 7205199

Browse files
committed
doc fixes again
1 parent d9b8c5b commit 7205199

File tree

2 files changed

+56
-56
lines changed

2 files changed

+56
-56
lines changed

src/deeponet.jl

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
3-
branch_activation = identity, trunk_activation = identity)
2+
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
3+
branch_activation = identity, trunk_activation = identity)
44
55
Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and
66
`trunk` are same.
@@ -44,25 +44,25 @@ Trunk net :
4444
)
4545
```
4646
"""
47-
function DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
48-
branch_activation = identity, trunk_activation = identity)
47+
function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
48+
branch_activation=identity, trunk_activation=identity)
4949

50-
# checks for last dimension size
51-
@argcheck branch[end] == trunk[end] "Branch and Trunk net must share the same amount of \
52-
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
53-
work."
50+
# checks for last dimension size
51+
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
52+
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
53+
work."
5454

55-
branch_net = Chain([Dense(branch[i] => branch[i+1], branch_activation)
56-
for i in 1:(length(branch)-1)]...)
55+
branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
56+
for i in 1:(length(branch) - 1)]...)
5757

58-
trunk_net = Chain([Dense(trunk[i] => trunk[i+1], trunk_activation)
59-
for i in 1:(length(trunk)-1)]...)
58+
trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
59+
for i in 1:(length(trunk) - 1)]...)
6060

61-
return DeepONet(branch_net, trunk_net)
61+
return DeepONet(branch_net, trunk_net)
6262
end
6363

6464
"""
65-
DeepONet(branch, trunk)
65+
DeepONet(branch, trunk)
6666
6767
Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the
6868
nets output should have the same first dimension.
@@ -107,28 +107,28 @@ Trunk net :
107107
```
108108
"""
109109
function DeepONet(branch::L1, trunk::L2) where {L1, L2}
110-
return @compact(; branch, trunk, dispatch = :DeepONet) do (u, y) # ::AbstractArray{<:Real, M} where {M}
111-
t = trunk(y) # p x N x nb
112-
b = branch(u) # p x nb
113-
114-
# checks for last dimension size
115-
@argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount \
116-
of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \
117-
won't work."
118-
119-
tᵀ = permutedims(t, (2, 1, 3)) # N x p x nb
120-
b_ = permutedims(reshape(b, size(b)..., 1), (1, 3, 2)) # p x 1 x nb
121-
G = batched_mul(tᵀ, b_) # N x 1 X nb
122-
@return dropdims(G; dims = 2)
123-
end
110+
return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y) # ::AbstractArray{<:Real, M} where {M}
111+
t = trunk(y) # p x N x nb
112+
b = branch(u) # p x nb
113+
114+
# checks for last dimension size
115+
@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same amount \
116+
of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \
117+
won't work."
118+
119+
tᵀ = permutedims(t, (2, 1, 3)) # N x p x nb
120+
b_ = permutedims(reshape(b, size(b)..., 1), (1, 3, 2)) # p x 1 x nb
121+
G = batched_mul(tᵀ, b_) # N x 1 X nb
122+
@return dropdims(G; dims=2)
123+
end
124124
end
125125

126126
function Base.show(io::IO, model::Lux.CompactLuxLayer{:DeepONet})
127-
Lux._print_wrapper_model(io, "Branch net :\n", model.layers.branch)
128-
print(io, "\n \n")
129-
Lux._print_wrapper_model(io, "Trunk net :\n", model.layers.trunk)
127+
Lux._print_wrapper_model(io, "Branch net :\n", model.layers.branch)
128+
print(io, "\n \n")
129+
Lux._print_wrapper_model(io, "Trunk net :\n", model.layers.trunk)
130130
end
131131

132132
function Base.show(io::IO, ::MIME"text/plain", x::CompactLuxLayer{:DeepONet})
133-
show(io, x)
133+
show(io, x)
134134
end

test/deeponet_tests.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,35 @@
1-
@testitem "DeepONet" setup = [SharedTestSetup] begin
2-
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
3-
rng_ = get_stable_rng()
1+
@testitem "DeepONet" setup=[SharedTestSetup] begin
2+
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
3+
rng_ = get_stable_rng()
44

5-
u = rand(64, 5) |> aType # sensor_points x nb
6-
y = rand(1, 10, 5) |> aType # ndims x N x nb
7-
out_size = (10, 5)
5+
u = rand(64, 5) |> aType # sensor_points x nb
6+
y = rand(1, 10, 5) |> aType # ndims x N x nb
7+
out_size = (10, 5)
88

9-
don_ = DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16))
9+
don_ = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
1010

11-
ps, st = Lux.setup(rng_, don_) |> dev
11+
ps, st = Lux.setup(rng_, don_) |> dev
1212

13-
@inferred don_((u, y), ps, st)
14-
@jet don_((u, y), ps, st)
13+
@inferred don_((u, y), ps, st)
14+
@jet don_((u, y), ps, st)
1515

16-
pred = first(don_((u, y), ps, st))
17-
@test size(pred) == out_size
16+
pred = first(don_((u, y), ps, st))
17+
@test size(pred) == out_size
1818

19-
don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
20-
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
19+
don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
20+
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
2121

22-
ps, st = Lux.setup(rng_, don_) |> dev
22+
ps, st = Lux.setup(rng_, don_) |> dev
2323

24-
@inferred don_((u, y), ps, st)
25-
@jet don_((u, y), ps, st)
24+
@inferred don_((u, y), ps, st)
25+
@jet don_((u, y), ps, st)
2626

27-
pred = first(don_((u, y), ps, st))
28-
@test size(pred) == out_size
27+
pred = first(don_((u, y), ps, st))
28+
@test size(pred) == out_size
2929

30-
don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
31-
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
32-
ps, st = Lux.setup(rng_, don_) |> dev
33-
@test_throws ArgumentError don_((u, y), ps, st)
34-
end
30+
don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
31+
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
32+
ps, st = Lux.setup(rng_, don_) |> dev
33+
@test_throws ArgumentError don_((u, y), ps, st)
34+
end
3535
end

0 commit comments

Comments
 (0)