Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 609e6dd

Browse files
authored
Merge pull request #7 from foldfelis/gpu
support GPU
2 parents 6a390d9 + 1817c9d commit 609e6dd

File tree

6 files changed

+37
-23
lines changed

6 files changed

+37
-23
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,19 +4,25 @@ authors = ["JingYu Ning <[email protected]> and contributors"]
44
version = "0.1.0"
55

66
[deps]
7+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
8+
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
79
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
810
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
911
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
1012
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
13+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1114
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
1215
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1316
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1417

1518
[compat]
19+
CUDA = "3.3"
20+
CUDAKernels = "0.3"
1621
DataDeps = "0.7"
1722
FFTW = "1.4"
1823
Fetch = "0.1"
1924
Flux = "0.12"
25+
KernelAbstractions = "0.7"
2026
MAT = "0.10"
2127
Tullio = "0.3"
2228
Zygote = "0.6"

docs/Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
5959

6060
[[NeuralOperators]]
6161
path = ".."
62-
uuid = "9ab867d4-5049-4b07-85bc-95379d8d6d9c"
62+
uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
6363
version = "0.1.0"
6464

6565
[[Parsers]]

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3-
NeuralOperators = "9ab867d4-5049-4b07-85bc-95379d8d6d9c"
3+
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"

example/burgers.jl

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,14 @@
11
using NeuralOperators
22
using Flux
3-
# using CUDA
3+
using CUDA
44

5-
# if has_cuda()
6-
# @info "CUDA is on"
7-
# device = gpu
8-
# CUDA.allowscalar(false)
9-
# else
5+
if has_cuda()
6+
@info "CUDA is on"
7+
device = gpu
8+
CUDA.allowscalar(false)
9+
else
1010
device = cpu
11-
# end
11+
end
1212

1313
m = FourierNeuralOperator() |> device
1414
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
@@ -24,13 +24,14 @@ n_test = 40
2424
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=20, shuffle=false)
2525

2626
function loss_test()
27-
l = 0
27+
l = 0f0
2828
for (𝐱, 𝐲) in loader_test
29+
𝐱, 𝐲 = device(𝐱), device(𝐲)
2930
l += loss(𝐱, 𝐲)
3031
end
3132
@info "loss: $(l/length(loader_test))"
3233
end
3334

3435
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
3536
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
36-
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 10)))
37+
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 5)))

src/NeuralOperators.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ module NeuralOperators
66
using Flux
77
using FFTW
88
using Tullio
9+
using CUDA
10+
using CUDAKernels
11+
using KernelAbstractions
912
using Zygote
1013

1114
function __init__()

src/fourier.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
export
22
SpectralConv1d,
3-
FourierOperator,
4-
FNO
3+
FourierOperator
54

65
struct SpectralConv1d{T, S}
76
weight::T
@@ -28,26 +27,31 @@ function SpectralConv1d(
2827

2928
return Chain(
3029
x -> Zygote.hook(real, x),
31-
SpectralConv1d(weights, in_chs, out_chs, modes, σ)
30+
SpectralConv1d(weights, in_chs, out_chs, modes, σ),
3231
)
3332
end
3433

3534
Flux.@functor SpectralConv1d
3635

36+
t(𝐱) = @tullio 𝐱ᵀ[i, j, k] := 𝐱[j, i, k]
37+
ein_mul(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
38+
3739
function (m::SpectralConv1d)(𝐱::AbstractArray)
38-
𝐱_fft = fft(𝐱, 2) # [in_chs, x, batch]
39-
𝐱_selected = 𝐱_fft[:, 1:m.modes, :] # [in_chs, modes, batch]
40+
𝐱ᵀ = t(𝐱) # [x, in_chs, batch] <- [in_chs, x, batch]
41+
𝐱_fft = fft(𝐱ᵀ, 1) # [x, in_chs, batch]
42+
𝐱_selected = 𝐱_fft[1:m.modes, :, :] # [modes, in_chs, batch]
4043

41-
# [out_chs, modes, batch] <- [in_chs, modes, batch] [out_chs, in_chs, modes]
42-
@tullio 𝐱_weighted[o, m, b] := 𝐱_selected[i, m, b] * m.weight[o, i, m]
44+
# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
45+
𝐱_weighted = ein_mul(𝐱_selected, m.weight)
4346

44-
s = size(𝐱_weighted)
45-
d = size(𝐱, 2) - m.modes
46-
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, s[1], d, s[3:end]...), dims=2)
47+
s = size(𝐱_weighted)[2:end]
48+
d = size(𝐱ᵀ, 1) - m.modes
49+
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, d, s...), dims=1)
4750

48-
𝐱_out = ifft(𝐱_padded, 2)
51+
𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch]
52+
𝐱_outᵀ = t(𝐱_out) # [out_chs, x, batch] <- [x, out_chs, batch]
4953

50-
return m.σ.(real(𝐱_out))
54+
return m.σ.(real(𝐱_outᵀ))
5155
end
5256

5357
function FourierOperator(

0 commit comments

Comments
 (0)