Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit af9d09f

Browse files
ayushinav and avik-pal
authored
feat: support multi-output for deeponet (#15)
* deeponet multi-output fix * test bug fix * format * compat with additional layer * format * test fixes * chore: short kearg syntax --------- Co-authored-by: Avik Pal <[email protected]>
1 parent f3d1252 commit af9d09f

File tree

4 files changed

+143
-51
lines changed

4 files changed

+143
-51
lines changed

src/NeuralOperators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using FFTW: FFTW, irfft, rfft
77
using Lux
88
using LuxCore: LuxCore, AbstractExplicitLayer
99
using LuxDeviceUtils: get_device, LuxAMDGPUDevice
10-
using NNlib: NNlib,
10+
using NNlib: NNlib, ⊠, batched_adjoint
1111
using Random: Random, AbstractRNG
1212
using Reexport: @reexport
1313

src/deeponet.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ Constructs a DeepONet composed of Dense layers. Make sure the last node of `bran
1111
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
1212
- `branch_activation`: activation function for branch net
1313
- `trunk_activation`: activation function for trunk net
14+
- `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
15+
for embeddings, defaults to `nothing`
1416
1517
## References
1618
@@ -33,8 +35,9 @@ julia> size(first(deeponet((u, y), ps, st)))
3335
(10, 5)
3436
```
3537
"""
36-
function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
37-
branch_activation=identity, trunk_activation=identity)
38+
function DeepONet(;
39+
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
40+
trunk_activation=identity, additional=nothing)
3841

3942
# checks for last dimension size
4043
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
@@ -47,7 +50,7 @@ function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
4750
trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
4851
for i in 1:(length(trunk) - 1)]...)
4952

50-
return DeepONet(branch_net, trunk_net)
53+
return DeepONet(branch_net, trunk_net; additional)
5154
end
5255

5356
"""
@@ -61,6 +64,11 @@ nets output should have the same first dimension.
6164
- `branch`: `Lux` network to be used as branch net.
6265
- `trunk`: `Lux` network to be used as trunk net.
6366
67+
## Keyword Arguments
68+
69+
- `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
70+
for embeddings, defaults to `nothing`
71+
6472
## References
6573
6674
[1] Lu Lu, Pengzhan Jin, George Em Karniadakis, "DeepONet: Learning nonlinear operators for
@@ -86,17 +94,15 @@ julia> size(first(deeponet((u, y), ps, st)))
8694
(10, 5)
8795
```
8896
"""
89-
function DeepONet(branch::L1, trunk::L2) where {L1, L2}
90-
return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y)
91-
t = trunk(y) # p x N x nb...
92-
b = branch(u) # p x nb...
97+
function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
98+
return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
99+
t = trunk(y) # p x N x nb
100+
b = branch(u) # p x u_size... x nb
93101

94-
@argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
95102
@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
96103
amount of nodes in the last layer. Otherwise \
97104
Σᵢ bᵢⱼ tᵢₖ won't work."
98105

99-
b_ = ndims(t) == ndims(b) ? b : reshape(b, size(b, 1), 1, size(b)[2:end]...)
100-
@return dropdims(sum(t .* b_; dims=1); dims=1)
106+
@return __project(b, t, additional)
101107
end
102108
end

src/utils.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,75 @@ end
99
# FIXME: This is not good for performance but that is okay for now
1010
return stack(*, eachslice(x; dims=3), eachslice(y; dims=3))
1111
end
12+
13+
@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
        additional::Nothing) where {T1, T2}
    # Contract the shared feature dimension p of the branch output (p × nb)
    # against the trunk output (p × N × nb); the result is N × nb.
    p, nb = size(b)
    branch_3d = reshape(b, p, 1, nb) # p × 1 × nb, broadcasts across N
    return dropdims(sum(branch_3d .* t; dims=1); dims=1) # N × nb
end
20+
21+
@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
        additional::Nothing) where {T1, T2}
    # b : p × u × nb, t : p × N × nb. When either middle dimension is a
    # singleton, a plain broadcast-and-sum contracts p; otherwise perform a
    # per-sample batched matrix product bᵀ t.
    if size(b, 2) == 1 || size(t, 2) == 1
        return sum(b .* t; dims=1) # 1 × N × nb
    end
    return __batched_mul(batched_adjoint(b), t) # u × N × nb
end
31+
32+
@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
        additional::Nothing) where {T1, T2, N}
    # b : p x u_size... x nb (arbitrary-rank branch output)
    # t : p x N x nb
    u_size = size(b)[2:(end - 1)]

    b_ = reshape(b, size(b, 1), u_size..., 1, size(b)[end])
    # p x u_size... x 1 x nb

    # `ntuple(_ -> 1, Val(N - 2))` builds the singleton dimensions at compile
    # time; splatting `ones(Int, N - 2)` allocated a vector and was not
    # type-stable, since the tuple length was unknown to inference.
    t_ = reshape(t, size(t, 1), ntuple(_ -> 1, Val(N - 2))..., size(t)[2:end]...)
    # p x (1,1,1...) x N x nb

    return dropdims(sum(b_ .* t_; dims=1); dims=1) # u_size... x N x nb
end
46+
47+
@inline function __project(
        b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
    # b : p x nb, t : p x N x nb. The feature dimension p is NOT summed out
    # here; the raw p x N x nb product is handed to `additional`, which maps
    # it to out_dims x N x nb.
    p, nb = size(b)
    return additional(reshape(b, p, 1, nb) .* t)
end
54+
55+
@inline function __project(
        b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
    # b : p x u x nb, t : p x N x nb. The p dimension is not contracted here;
    # `additional` consumes the raw product and produces the output embedding.
    if size(b, 2) == 1 || size(t, 2) == 1
        return additional(b .* t) # p x N x nb => out_dims x N x nb
    end

    lifted_b = reshape(b, size(b, 1), size(b, 2), 1, size(b, 3)) # p x u x 1 x nb
    lifted_t = reshape(t, size(t, 1), 1, size(t, 2), size(t, 3)) # p x 1 x N x nb
    return additional(lifted_b .* lifted_t) # p x u x N x nb => out_size x N x nb
end
69+
70+
@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
        additional::T) where {T1, T2, N, T}
    # b : p x u_size... x nb (arbitrary-rank branch output)
    # t : p x N x nb
    u_size = size(b)[2:(end - 1)]

    b_ = reshape(b, size(b, 1), u_size..., 1, size(b)[end])
    # p x u_size... x 1 x nb

    # `ntuple(_ -> 1, Val(N - 2))` produces the singleton dimensions as a
    # compile-time tuple; the previous `ones(Int, N - 2)` splat allocated a
    # vector and was type-unstable.
    t_ = reshape(t, size(t, 1), ntuple(_ -> 1, Val(N - 2))..., size(t)[2:end]...)
    # p x (1,1,1...) x N x nb

    return additional(b_ .* t_) # p x u_size... x N x nb => out_size x N x nb
end

test/deeponet_tests.jl

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,63 @@
22
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
33
rng = StableRNG(12345)
44

5-
u = rand(Float32, 64, 5) |> aType # sensor_points x nb
6-
y = rand(Float32, 1, 10, 5) |> aType # ndims x N x nb
7-
out_size = (10, 5)
8-
9-
deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
10-
display(deeponet)
11-
ps, st = Lux.setup(rng, deeponet) |> dev
12-
13-
@inferred deeponet((u, y), ps, st)
14-
@jet deeponet((u, y), ps, st)
15-
16-
pred = first(deeponet((u, y), ps, st))
17-
@test size(pred) == out_size
18-
19-
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
20-
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
21-
display(deeponet)
22-
ps, st = Lux.setup(rng, deeponet) |> dev
23-
24-
@inferred deeponet((u, y), ps, st)
25-
@jet deeponet((u, y), ps, st)
26-
27-
pred = first(deeponet((u, y), ps, st))
28-
@test size(pred) == out_size
29-
30-
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
31-
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
32-
display(deeponet)
33-
ps, st = Lux.setup(rng, deeponet) |> dev
34-
@test_throws ArgumentError deeponet((u, y), ps, st)
35-
36-
@testset "higher-dim input #7" begin
37-
u = ones(Float32, 10, 10, 10) |> aType
38-
v = ones(Float32, 1, 10, 10) |> aType
39-
deeponet = DeepONet(; branch=(10, 10, 10), trunk=(1, 10, 10))
40-
display(deeponet)
5+
setups = [
6+
(u_size=(64, 5), y_size=(1, 10, 5), out_size=(10, 5),
7+
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar"),
8+
(u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(1, 10, 5),
9+
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar II"),
10+
(u_size=(64, 3, 5), y_size=(4, 10, 5), out_size=(3, 10, 5),
11+
branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Vector"),
12+
(u_size=(64, 4, 3, 3, 5), y_size=(4, 10, 5), out_size=(4, 3, 3, 10, 5),
13+
branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Tensor")]
14+
15+
@testset "$(setup.name)" for setup in setups
16+
u = rand(Float32, setup.u_size...) |> aType
17+
y = rand(Float32, setup.y_size...) |> aType
18+
deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk)
19+
20+
ps, st = Lux.setup(rng, deeponet) |> dev
21+
@inferred first(deeponet((u, y), ps, st))
22+
@jet first(deeponet((u, y), ps, st))
23+
24+
pred = first(deeponet((u, y), ps, st))
25+
@test setup.out_size == size(pred)
26+
end
27+
28+
setups = [
29+
(u_size=(64, 5), y_size=(1, 10, 5), out_size=(4, 10, 5),
30+
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
31+
additional=Dense(16 => 4), name="Scalar"),
32+
(u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(4, 10, 5),
33+
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
34+
additional=Dense(16 => 4), name="Scalar II"),
35+
(u_size=(64, 3, 5), y_size=(8, 10, 5), out_size=(4, 3, 10, 5),
36+
branch=(64, 32, 32, 16), trunk=(8, 8, 8, 16),
37+
additional=Dense(16 => 4), name="Vector")]
38+
39+
@testset "Additional layer: $(setup.name)" for setup in setups
40+
u = rand(Float32, setup.u_size...) |> aType
41+
y = rand(Float32, setup.y_size...) |> aType
42+
deeponet = DeepONet(;
43+
branch=setup.branch, trunk=setup.trunk, additional=setup.additional)
44+
4145
ps, st = Lux.setup(rng, deeponet) |> dev
46+
@inferred first(deeponet((u, y), ps, st))
47+
@jet first(deeponet((u, y), ps, st))
48+
49+
pred = first(deeponet((u, y), ps, st))
50+
@test setup.out_size == size(pred)
51+
end
52+
53+
@testset "Embedding layer mismatch" begin
54+
u = rand(Float32, 64, 5) |> aType
55+
y = rand(Float32, 1, 10, 5) |> aType
4256

43-
y, st_ = deeponet((u, v), ps, st)
44-
@test size(y) == (10, 10)
57+
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
58+
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
4559

46-
@inferred deeponet((u, v), ps, st)
47-
@jet deeponet((u, v), ps, st)
60+
ps, st = Lux.setup(rng, deeponet) |> dev
61+
@test_throws ArgumentError deeponet((u, y), ps, st)
4862
end
4963
end
5064
end

0 commit comments

Comments
 (0)