Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit dc018ec

Browse files
committed
fix wrong dim on permutation and shrink test
1 parent 8e6e4a5 commit dc018ec

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

src/fourier.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Base.ndims(::SpectralConv{N}) where {N} = N
3434
spectral_conv(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
3535

3636
function (m::SpectralConv)(𝐱::AbstractArray)
37-
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
37+
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (2:ndims(m)+1..., 1, ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
3838
𝐱_fft = fft(𝐱ᵀ, 1:ndims(m)) # [x, in_chs, batch]
3939

4040
# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
@@ -48,7 +48,7 @@ function (m::SpectralConv)(𝐱::AbstractArray)
4848
𝐱_padded = cat(𝐱_shaped, pad, dims=1:ndims(m))
4949

5050
𝐱_out = ifft(𝐱_padded, 1:ndims(m)) # [x, out_chs, batch]
51-
𝐱_outᵀ = permutedims(real(𝐱_out), (2:ndims(m)+1..., 1, ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch]
51+
𝐱_outᵀ = permutedims(real(𝐱_out), (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch]
5252

5353
return m.σ.(𝐱_outᵀ)
5454
end

test/fourier.jl

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "SpectralConv" begin
1+
@testset "SpectralConv1d" begin
22
modes = (16, )
33
ch = 64 => 64
44

@@ -8,16 +8,15 @@
88
)
99
@test ndims(SpectralConv(ch, modes)) == 1
1010

11-
𝐱, _ = get_burgers_data(n=1000)
12-
@test size(m(𝐱)) == (64, 1024, 1000)
11+
𝐱, _ = get_burgers_data(n=5)
12+
@test size(m(𝐱)) == (64, 1024, 5)
1313

14-
T = Float32
1514
loss(x, y) = Flux.mse(m(x), y)
16-
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
15+
data = [(𝐱, rand(Float32, 64, 1024, 5))]
1716
Flux.train!(loss, params(m), data, Flux.ADAM())
1817
end
1918

20-
@testset "FourierOperator" begin
19+
@testset "FourierOperator1d" begin
2120
modes = (16, )
2221
ch = 64 => 64
2322

@@ -26,10 +25,45 @@ end
2625
FourierOperator(ch, modes)
2726
)
2827

29-
𝐱, _ = get_burgers_data(n=1000)
30-
@test size(m(𝐱)) == (64, 1024, 1000)
28+
𝐱, _ = get_burgers_data(n=5)
29+
@test size(m(𝐱)) == (64, 1024, 5)
3130

3231
loss(x, y) = Flux.mse(m(x), y)
33-
data = [(Float32.(𝐱[:, :, 1:5]), rand(Float32, 64, 1024, 5))]
32+
data = [(𝐱, rand(Float32, 64, 1024, 5))]
33+
Flux.train!(loss, params(m), data, Flux.ADAM())
34+
end
35+
36+
@testset "SpectralConv2d" begin
37+
modes = (16, 16)
38+
ch = 64 => 64
39+
40+
m = Chain(
41+
Dense(1, 64),
42+
SpectralConv(ch, modes)
43+
)
44+
@test ndims(SpectralConv(ch, modes)) == 2
45+
46+
𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20)
47+
@test size(m(𝐱)) == (64, 22, 22, 5)
48+
49+
loss(x, y) = Flux.mse(m(x), y)
50+
data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
51+
Flux.train!(loss, params(m), data, Flux.ADAM())
52+
end
53+
54+
@testset "FourierOperator2d" begin
55+
modes = (16, 16)
56+
ch = 64 => 64
57+
58+
m = Chain(
59+
Dense(1, 64),
60+
FourierOperator(ch, modes)
61+
)
62+
63+
𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20)
64+
@test size(m(𝐱)) == (64, 22, 22, 5)
65+
66+
loss(x, y) = Flux.mse(m(x), y)
67+
data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
3468
Flux.train!(loss, params(m), data, Flux.ADAM())
3569
end

0 commit comments

Comments
 (0)