Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit c394788

Browse files
committed
add test for train
1 parent d1a536b commit c394788

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

src/fourier.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Tullio
44

55
export
66
SpectralConv1d,
7+
FourierOperator,
78
FNO
89

910
struct SpectralConv1d{T,S}
@@ -74,5 +75,3 @@ function FNO()
7475
flatten
7576
)
7677
end
77-
78-
# loss(m::SpectralConv1d, x, x̂) = sum(abs2, x̂ .- m(x)) / len

test/fourier.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,11 @@ end
1515

1616
@testset "FNO" begin
1717
𝐱, 𝐲 = get_data()
18+
𝐱, 𝐲 = Float32.(𝐱), Float32.(𝐲)
1819
@test size(FNO()(𝐱)) == size(𝐲)
20+
21+
# m = FNO()
22+
# loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
23+
# data = [(𝐱[:, :, 1:5], 𝐲[:, 1:5])]
24+
# Flux.train!(loss, params(m), data, Flux.ADAM())
1925
end

0 commit comments

Comments
 (0)