This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 9e4a069

Fix deeponet
1 parent 67d9007 commit 9e4a069

7 files changed: +98 -98 lines changed


.buildkite/pipeline.yml

Lines changed: 1 addition & 1 deletion
@@ -52,4 +52,4 @@ steps:
     env:
       RETESTITEMS_NWORKERS: 4
       RETESTITEMS_NWORKER_THREADS: 2
-      SECRET_CODECOV_TOKEN: "Tg/DGJmBhzxJQBcGajfE2McAOuNVa6zpMZGw0rYTTTGpE7dsBg8cDuj5D9tmLZYdNXJxlkSrjQkjQiPelqECIlMieRveDJ/S3bnA1meJk5p8/PIzwJzQiMCrXpX+xbhcHPn9aQoMmloqP/u6eJ7ToYineDiGbtvQnofVvH0cTgEj/xD15Dflo3K9m/w5/vdvaRbSrxIMc1Z7md/m2XSJJHyLD2Zkir2YWk2cZpyq/S7mA0zL2Yeur27tkzsjSPN/Y+vS5+LLdr5yxo9OVTCAJAZDVsBJGf1Ynd8y4T7usfK+fa41Se48ZpKA/VZtSSZQKdTHM0JcVpqe+Z5L9zbGGg==;U2FsdGVkX18VAT6PhLvJvEVkHs4vFg/vBLTECZAdWznsrPEISjpgl00GTYqrxMw30trS4RDWRSdY1TRYAC85QQ=="
+      SECRET_CODECOV_TOKEN: "vn5M+4wSwUFje6fl6UB/Q/rTmLHu3OlCCMgoPOXPQHYpLZTLz2hOHsV44MadAnxw8MsNVxLKZlXBKqP3IydU9gUfV7QUBtnvbUmIvgUHbr+r0bVaIVVhw6cnd0s8/b+561nU483eRJd35bjYDOlO+V5eDxkbdh/0bzLefXNXy5+ALxsBYzsp75Sx/9nuREfRqWwU6S45mne2ZlwCDpZlFvBDXQ2ICKYXpA45MpxhW9RuqfpQdi6sSR6I/HdHkV2cuJO99dqqh8xfUy6vWPC/+HUVrn9ETsrXtayX1MX3McKj869htGICpR8vqd311HTONYVprH2AN1bJqr5MOIZ8Xg==;U2FsdGVkX1+W55pTI7zq+NwYrbK6Cgqe+Gp8wMCmXY+W10aXTB0bS6zshiDYSQ1Y3piT91xFyNhS+9AsajY0yQ=="

Project.toml

Lines changed: 2 additions & 6 deletions
@@ -31,8 +31,7 @@ ConcreteStructs = "0.2.3"
 Documenter = "1.4.1"
 ExplicitImports = "1.6.0"
 FFTW = "1.8.0"
-Lux = "0.5.53"
-LuxAMDGPU = "0.2.3"
+Lux = "0.5.56"
 LuxCUDA = "0.3.2"
 LuxCore = "0.1.15"
 LuxTestUtils = "0.1.15"
@@ -43,7 +42,6 @@ Random = "1.10"
 ReTestItems = "1.24.0"
 Reexport = "1.2.2"
 StableRNGs = "1.0.2"
-Statistics = "1.10"
 Test = "1.10"
 WeightInitializers = "0.1.7"
 Zygote = "0.6.70"
@@ -53,15 +51,13 @@ julia = "1.10"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
-LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Aqua", "Documenter", "ExplicitImports", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Optimisers", "ReTestItems", "StableRNGs", "Statistics", "Test", "Zygote"]
+test = ["Aqua", "Documenter", "ExplicitImports", "AMDGPU", "LuxCUDA", "LuxTestUtils", "Optimisers", "ReTestItems", "StableRNGs", "Test", "Zygote"]

src/deeponet.jl

Lines changed: 46 additions & 48 deletions
@@ -21,27 +21,31 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
-
-# output
-
-Branch net :
-(
-    Chain(
+julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
+@compact(
+    branch = Chain(
         layer_1 = Dense(64 => 32),  # 2_080 parameters
         layer_2 = Dense(32 => 32),  # 1_056 parameters
         layer_3 = Dense(32 => 16),  # 528 parameters
     ),
-)
-
-Trunk net :
-(
-    Chain(
+    trunk = Chain(
         layer_1 = Dense(1 => 8),    # 16 parameters
         layer_2 = Dense(8 => 8),    # 72 parameters
         layer_3 = Dense(8 => 16),   # 144 parameters
     ),
-)
+) do (u, y)
+    t = trunk(y)
+    b = branch(u)
+    @argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
+    @argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."
+    b_ = if ndims(t) == ndims(b)
+        b
+    else
+        reshape(b, size(b, 1), 1, (size(b))[2:end]...)
+    end
+    return dropdims(sum(t .* b_; dims = 1); dims = 1)
+end # Total: 3_896 parameters,
+    # plus 0 states.
 ```
 """
 function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
@@ -81,54 +85,48 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
-trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
-don_ = DeepONet(branch_net, trunk_net)
+julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
 
-# output
+julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
 
-Branch net :
-(
-    Chain(
+julia> deeponet = DeepONet(branch_net, trunk_net)
+@compact(
+    branch = Chain(
         layer_1 = Dense(64 => 32),  # 2_080 parameters
         layer_2 = Dense(32 => 32),  # 1_056 parameters
         layer_3 = Dense(32 => 16),  # 528 parameters
     ),
-)
-
-Trunk net :
-(
-    Chain(
+    trunk = Chain(
        layer_1 = Dense(1 => 8),    # 16 parameters
        layer_2 = Dense(8 => 8),    # 72 parameters
        layer_3 = Dense(8 => 16),   # 144 parameters
     ),
-)
+) do (u, y)
+    t = trunk(y)
+    b = branch(u)
+    @argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
+    @argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."
+    b_ = if ndims(t) == ndims(b)
+        b
+    else
+        reshape(b, size(b, 1), 1, (size(b))[2:end]...)
+    end
+    return dropdims(sum(t .* b_; dims = 1); dims = 1)
+end # Total: 3_896 parameters,
+    # plus 0 states.
 ```
 """
 function DeepONet(branch::L1, trunk::L2) where {L1, L2}
-    return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y) # ::AbstractArray{<:Real, M} where {M}
-        t = trunk(y) # p x N x nb
-        b = branch(u) # p x nb
-
-        # checks for last dimension size
-        @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same amount \
-                                          of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \
-                                          won't work."
-
-        tᵀ = permutedims(t, (2, 1, 3)) # N x p x nb
-        b_ = permutedims(reshape(b, size(b)..., 1), (1, 3, 2)) # p x 1 x nb
-        G = batched_mul(tᵀ, b_) # N x 1 X nb
-        @return dropdims(G; dims=2)
-    end
-end
+    return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y)
+        t = trunk(y) # p x N x nb...
+        b = branch(u) # p x nb...
 
-function Base.show(io::IO, model::Lux.CompactLuxLayer{:DeepONet})
-    Lux._print_wrapper_model(io, "Branch net :\n", model.layers.branch)
-    print(io, "\n \n")
-    Lux._print_wrapper_model(io, "Trunk net :\n", model.layers.trunk)
-end
+        @argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
+        @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
+                                          amount of nodes in the last layer. Otherwise \
+                                          Σᵢ bᵢⱼ tᵢₖ won't work."
 
-function Base.show(io::IO, ::MIME"text/plain", x::CompactLuxLayer{:DeepONet})
-    show(io, x)
+        b_ = ndims(t) == ndims(b) ? b : reshape(b, size(b, 1), 1, size(b)[2:end]...)
+        @return dropdims(sum(t .* b_; dims=1); dims=1)
+    end
 end
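The forward pass above replaces the old `permutedims` + `batched_mul` contraction with a broadcasted multiply and a sum over the shared embedding dimension. A minimal sketch in plain Base Julia (the sizes `p`, `N`, `nb` are illustrative, not taken from the commit) showing that `dropdims(sum(t .* b_; dims=1); dims=1)` reproduces the per-sample matrix products of the removed `batched_mul` path:

```julia
# Sketch of the contraction Gⱼₖ = Σᵢ bᵢₖ tᵢⱼₖ used in the new DeepONet forward pass.
p, N, nb = 16, 10, 5

t = rand(Float32, p, N, nb)   # stand-in for trunk output:  p x N x nb
b = rand(Float32, p, nb)      # stand-in for branch output: p x nb

# new path: make the branch broadcastable against the trunk, then sum out p
b_ = reshape(b, size(b, 1), 1, size(b)[2:end]...)      # p x 1 x nb
G_new = dropdims(sum(t .* b_; dims=1); dims=1)         # N x nb

# reference: the same contraction as explicit per-sample matrix-vector products
G_ref = similar(G_new)
for k in 1:nb
    G_ref[:, k] = transpose(t[:, :, k]) * b[:, k]      # (N x p) * (p,) -> N
end

@assert G_new ≈ G_ref
```

Because the reshape only inserts a singleton second axis, the same broadcast also covers branch outputs whose trailing dimensions already match the trunk's (the `ndims(t) == ndims(b)` case), which is what the new higher-dim testset below exercises.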

test/deeponet_tests.jl

Lines changed: 29 additions & 16 deletions
@@ -1,35 +1,48 @@
 @testitem "DeepONet" setup=[SharedTestSetup] begin
     @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
-        rng_ = get_stable_rng()
+        rng = StableRNG(12345)
 
-        u = rand(64, 5) |> aType # sensor_points x nb
-        y = rand(1, 10, 5) |> aType # ndims x N x nb
+        u = rand(Float32, 64, 5) |> aType # sensor_points x nb
+        y = rand(Float32, 1, 10, 5) |> aType # ndims x N x nb
         out_size = (10, 5)
 
-        don_ = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
+        deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
 
-        ps, st = Lux.setup(rng_, don_) |> dev
+        ps, st = Lux.setup(rng, deeponet) |> dev
 
-        @inferred don_((u, y), ps, st)
-        @jet don_((u, y), ps, st)
+        @inferred deeponet((u, y), ps, st)
+        @jet deeponet((u, y), ps, st)
 
-        pred = first(don_((u, y), ps, st))
+        pred = first(deeponet((u, y), ps, st))
         @test size(pred) == out_size
 
-        don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
+        deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
             Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
 
-        ps, st = Lux.setup(rng_, don_) |> dev
+        ps, st = Lux.setup(rng, deeponet) |> dev
 
-        @inferred don_((u, y), ps, st)
-        @jet don_((u, y), ps, st)
+        @inferred deeponet((u, y), ps, st)
+        @jet deeponet((u, y), ps, st)
 
-        pred = first(don_((u, y), ps, st))
+        pred = first(deeponet((u, y), ps, st))
         @test size(pred) == out_size
 
-        don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
+        deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
             Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
-        ps, st = Lux.setup(rng_, don_) |> dev
-        @test_throws ArgumentError don_((u, y), ps, st)
+        ps, st = Lux.setup(rng, deeponet) |> dev
+        @test_throws ArgumentError deeponet((u, y), ps, st)
+
+        @testset "higher-dim input #7" begin
+            u = ones(Float32, 10, 10, 10) |> aType
+            v = ones(Float32, 1, 10, 10) |> aType
+            deeponet = DeepONet(; branch=(10, 10, 10), trunk=(1, 10, 10))
+            ps, st = Lux.setup(rng, deeponet) |> dev
+
+            y, st_ = deeponet((u, v), ps, st)
+            @test size(y) == (10, 10)
+
+            @inferred deeponet((u, v), ps, st)
+            @jet deeponet((u, v), ps, st)
+        end
     end
 end
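The new "higher-dim input #7" testset checks the shape bookkeeping when the branch input carries extra batch dimensions. A shape-only sketch in plain Julia (the random arrays below are stand-ins for the branch and trunk outputs, not the actual Lux layers):

```julia
# Shape sketch for the higher-dim case: branch maps 10 => 10 along dim 1,
# trunk maps 1 => 10 along dim 1, so both outputs end up 10 x 10 x 10.
u = ones(Float32, 10, 10, 10)               # sensor_points x extra batch dims
v = ones(Float32, 1, 10, 10)                # coordinate dim x N x nb

b = rand(Float32, 10, size(u)[2:end]...)    # stand-in for branch(u): 10 x 10 x 10
t = rand(Float32, 10, size(v)[2:end]...)    # stand-in for trunk(v):  10 x 10 x 10

# ndims(t) == ndims(b), so b is used as-is and the embedding dim is summed out
out = dropdims(sum(t .* b; dims=1); dims=1)
@assert size(out) == (10, 10)               # matches `@test size(y) == (10, 10)`
```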

test/fno_tests.jl

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 @testitem "Fourier Neural Operator" setup=[SharedTestSetup] begin
     @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
-        rng = get_stable_rng()
+        rng = StableRNG(12345)
 
         setups = [
             (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1),

test/layers_tests.jl

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 @testitem "SpectralConv & SpectralKernel" setup=[SharedTestSetup] begin
     @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
-        rng = get_stable_rng()
+        rng = StableRNG(12345)
 
         opconv = [SpectralConv, SpectralKernel]
         setups = [

test/shared_testsetup.jl

Lines changed: 18 additions & 25 deletions
@@ -1,30 +1,28 @@
 @testsetup module SharedTestSetup
 import Reexport: @reexport
 
-@reexport using Lux, LuxCUDA, LuxAMDGPU, Zygote, Optimisers, Random, StableRNGs, Statistics
+@reexport using Lux, LuxCUDA, AMDGPU, Zygote, Optimisers, Random, StableRNGs
 using LuxTestUtils: @jet, @test_gradients
 
 CUDA.allowscalar(false)
 
 const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "All")
 
 cpu_testing() = BACKEND_GROUP == "All" || BACKEND_GROUP == "CPU"
-cuda_testing() = (BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA") && LuxCUDA.functional()
+function cuda_testing()
+    return (BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA") &&
+           LuxDeviceUtils.functional(LuxCUDADevice)
+end
 function amdgpu_testing()
-    return (BACKEND_GROUP == "All" || BACKEND_GROUP == "AMDGPU") && LuxAMDGPU.functional()
+    return (BACKEND_GROUP == "All" || BACKEND_GROUP == "AMDGPU") &&
+           LuxDeviceUtils.functional(LuxAMDGPUDevice)
 end
 
 const MODES = begin
-    # Mode, Array Type, Device Function, GPU?
-    cpu_mode = ("CPU", Array, LuxCPUDevice(), false)
-    cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true)
-    amdgpu_mode = ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true)
-
     modes = []
-    cpu_testing() && push!(modes, cpu_mode)
-    cuda_testing() && push!(modes, cuda_mode)
-    amdgpu_testing() && push!(modes, amdgpu_mode)
-
+    cpu_testing() && push!(modes, ("CPU", Array, LuxCPUDevice(), false))
+    cuda_testing() && push!(modes, ("CUDA", CuArray, LuxCUDADevice(), true))
+    amdgpu_testing() && push!(modes, ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true))
     modes
 end
 
@@ -36,28 +34,23 @@ function get_default_rng(mode::String)
     return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
 end
 
-get_stable_rng(seed=12345) = StableRNG(seed)
+train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...)
 
-default_loss_function(model, ps, x, y) = mean(abs2, y .- model(x, ps))
+function train!(loss, backend, model, ps, st, data; epochs=10)
+    l1 = loss(model, ps, st, first(data))
 
-train!(args...; kwargs...) = train!(default_loss_function, args...; kwargs...)
-
-function train!(loss, model, ps, st, data; epochs=10)
-    m = StatefulLuxLayer{true}(model, ps, st)
-
-    l1 = loss(m, ps, first(data)...)
-    st_opt = Optimisers.setup(Adam(0.01f0), ps)
+    tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.01f0))
     for _ in 1:epochs, (x, y) in data
-        _, gs, _, _ = Zygote.gradient(loss, m, ps, x, y)
-        Optimisers.update!(st_opt, ps, gs)
+        _, _, _, tstate = Lux.Experimental.single_train_step!(backend, loss, (x, y), tstate)
    end
-    l2 = loss(m, ps, first(data)...)
+
+    l2 = loss(model, ps, st, first(data))
 
     return l2, l1
 end
 
 export @jet, @test_gradients, check_approx
 export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng,
-       get_stable_rng, train!
+       train!
 
 end
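The reworked `train!` helper now routes through `Lux.Experimental.TrainState` and `single_train_step!`, with `MSELoss()` and `AutoZygote()` as defaults. A hedged usage sketch of how a test might call it; the `Dense` model and random data are illustrative only and not part of this commit, and the helper is assumed to be in scope via the `SharedTestSetup` module above:

```julia
# Illustrative only: exercises the reworked train! helper defined above.
using Lux, Optimisers, Zygote, Random

model = Dense(2 => 1)                             # hypothetical toy model
ps, st = Lux.setup(Random.default_rng(), model)

x = rand(Float32, 2, 32)
y = rand(Float32, 1, 32)
data = [(x, y)]                                   # single-batch "dataset"

# defaults: MSELoss() objective, AutoZygote() backend, 10 epochs of Adam(0.01)
l2, l1 = train!(model, ps, st, data; epochs=10)   # losses after / before training
```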
