Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 56d5d00

Browse files
committed
test: display layers
1 parent 9b917c0 commit 56d5d00

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

test/deeponet_tests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
out_size = (10, 5)
88

99
deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
10-
10+
display(deeponet)
1111
ps, st = Lux.setup(rng, deeponet) |> dev
1212

1313
@inferred deeponet((u, y), ps, st)
@@ -18,7 +18,7 @@
1818

1919
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
2020
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
21-
21+
display(deeponet)
2222
ps, st = Lux.setup(rng, deeponet) |> dev
2323

2424
@inferred deeponet((u, y), ps, st)
@@ -29,13 +29,15 @@
2929

3030
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
3131
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
32+
display(deeponet)
3233
ps, st = Lux.setup(rng, deeponet) |> dev
3334
@test_throws ArgumentError deeponet((u, y), ps, st)
3435

3536
@testset "higher-dim input #7" begin
3637
u = ones(Float32, 10, 10, 10) |> aType
3738
v = ones(Float32, 1, 10, 10) |> aType
3839
deeponet = DeepONet(; branch=(10, 10, 10), trunk=(1, 10, 10))
40+
display(deeponet)
3941
ps, st = Lux.setup(rng, deeponet) |> dev
4042

4143
y, st_ = deeponet((u, v), ps, st)

test/fno_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010

1111
@testset "$(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups
1212
fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.permuted)
13+
display(fno)
14+
ps, st = Lux.setup(rng, fno) |> dev
1315

1416
x = rand(rng, Float32, setup.x_size...) |> aType
1517
y = rand(rng, Float32, setup.y_size...) |> aType
1618

17-
ps, st = Lux.setup(rng, fno) |> dev
18-
1919
@inferred fno(x, ps, st)
2020
@jet fno(x, ps, st)
2121

test/layers_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
l1 = p ? Conv(ntuple(_ -> 1, length(setup.m)), in_chs => first(ch)) :
2323
Dense(in_chs => first(ch))
2424
m = Chain(l1, op(ch, setup.m; setup.permuted))
25+
display(m)
2526
ps, st = Lux.setup(rng, m) |> dev
2627

2728
x = rand(rng, Float32, setup.x_size...) |> aType

0 commit comments

Comments
 (0)