Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 3909c71

Browse files
committed
fix wrong dim on permutation
1 parent 8e6e4a5 commit 3909c71

File tree

3 files changed

+42
-6
lines changed

3 files changed

+42
-6
lines changed

src/fourier.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Base.ndims(::SpectralConv{N}) where {N} = N
3434
spectral_conv(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
3535

3636
function (m::SpectralConv)(𝐱::AbstractArray)
37-
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
37+
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (2:ndims(m)+1..., 1, ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
3838
𝐱_fft = fft(𝐱ᵀ, 1:ndims(m)) # [x, in_chs, batch]
3939

4040
# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
@@ -48,7 +48,7 @@ function (m::SpectralConv)(𝐱::AbstractArray)
4848
𝐱_padded = cat(𝐱_shaped, pad, dims=1:ndims(m))
4949

5050
𝐱_out = ifft(𝐱_padded, 1:ndims(m)) # [x, out_chs, batch]
51-
𝐱_outᵀ = permutedims(real(𝐱_out), (2:ndims(m)+1..., 1, ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch]
51+
𝐱_outᵀ = permutedims(real(𝐱_out), (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch]
5252

5353
return m.σ.(𝐱_outᵀ)
5454
end

test/fourier.jl

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "SpectralConv" begin
1+
@testset "SpectralConv1d" begin
22
modes = (16, )
33
ch = 64 => 64
44

@@ -17,7 +17,7 @@
1717
Flux.train!(loss, params(m), data, Flux.ADAM())
1818
end
1919

20-
@testset "FourierOperator" begin
20+
@testset "FourierOperator1d" begin
2121
modes = (16, )
2222
ch = 64 => 64
2323

@@ -33,3 +33,39 @@ end
3333
data = [(Float32.(𝐱[:, :, 1:5]), rand(Float32, 64, 1024, 5))]
3434
Flux.train!(loss, params(m), data, Flux.ADAM())
3535
end
36+
37+
@testset "SpectralConv2d" begin
38+
modes = (16, 16)
39+
ch = 64 => 64
40+
41+
m = Chain(
42+
Dense(1, 64),
43+
SpectralConv(ch, modes)
44+
)
45+
@test ndims(SpectralConv(ch, modes)) == 2
46+
47+
𝐱, _ , _, _ = get_darcy_flow_data()
48+
@test size(m(𝐱)) == (64, 85, 85, 1024)
49+
50+
T = Float32
51+
loss(x, y) = Flux.mse(m(x), y)
52+
data = [(T.(𝐱[:, :, :, 1:5]), rand(T, 64, 85, 85, 5))]
53+
Flux.train!(loss, params(m), data, Flux.ADAM())
54+
end
55+
56+
@testset "FourierOperator2d" begin
57+
modes = (16, 16)
58+
ch = 64 => 64
59+
60+
m = Chain(
61+
Dense(1, 64),
62+
FourierOperator(ch, modes)
63+
)
64+
65+
𝐱, _ , _, _ = get_darcy_flow_data()
66+
@test size(m(𝐱)) == (64, 85, 85, 1024)
67+
68+
loss(x, y) = Flux.mse(m(x), y)
69+
data = [(Float32.(𝐱[:, :, :, 1:5]), rand(Float32, 64, 85, 85, 5))]
70+
Flux.train!(loss, params(m), data, Flux.ADAM())
71+
end

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using Flux
44

55
@testset "NeuralOperators.jl" begin
6-
include("data.jl")
6+
# include("data.jl")
77
include("fourier.jl")
8-
include("model.jl")
8+
# include("model.jl")
99
end

0 commit comments

Comments
 (0)