Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit e56f5d6

Browse files
committed
enable gpu
1 parent 6a390d9 commit e56f5d6

File tree

4 files changed: 65 additions and 17 deletions

4 files changed: 65 additions and 17 deletions

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,13 @@ authors = ["JingYu Ning <[email protected]> and contributors"]
44
version = "0.1.0"
55

66
[deps]
7+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
8+
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
79
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
810
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
911
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
1012
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
13+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1114
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
1215
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1316
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

example/a.jl

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,38 @@
1+
using Zygote
2+
using Flux
3+
using CUDA
4+
using FFTW
5+
using Tullio
6+
7+
if has_cuda()
8+
@info "CUDA is on"
9+
device = gpu
10+
CUDA.allowscalar(true)
11+
else
12+
device = cpu
13+
end
14+
15+
function t(𝐱)
16+
@tullio 𝐱ᵀ[a, b, c] := 𝐱[b, a, c]
17+
18+
return 𝐱ᵀ
19+
end
20+
21+
m = Chain(
22+
Dense(2, 5),
23+
t,
24+
x->Zygote.hook(real, x),
25+
x->real(fft(x, 1)),
26+
t,
27+
Dense(5, 5),
28+
t,
29+
x->Zygote.hook(real, x),
30+
x->real(ifft(x, 1)),
31+
t,
32+
x->sum(x)
33+
) |> device
34+
35+
loss(x, y) = Flux.mse(m(x), y)
36+
37+
data = [(rand(Float32, 2, 100, 10), rand(Float32, 10))] |> device
38+
Flux.train!(loss, params(m), data, Flux.ADAM())

example/burgers.jl

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,14 @@
11
using NeuralOperators
22
using Flux
3-
# using CUDA
3+
using CUDA
44

5-
# if has_cuda()
6-
# @info "CUDA is on"
7-
# device = gpu
8-
# CUDA.allowscalar(false)
9-
# else
5+
if has_cuda()
6+
@info "CUDA is on"
7+
device = gpu
8+
CUDA.allowscalar(false)
9+
else
1010
device = cpu
11-
# end
11+
end
1212

1313
m = FourierNeuralOperator() |> device
1414
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
@@ -33,4 +33,4 @@ end
3333

3434
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
3535
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
36-
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 10)))
36+
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt))

src/fourier.jl

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
using CUDA, CUDAKernels, KernelAbstractions
2+
13
export
24
SpectralConv1d,
35
FourierOperator,
@@ -15,6 +17,9 @@ function c_glorot_uniform(dims...)
1517
return Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im
1618
end
1719

20+
t(𝐱) = @tullio 𝐱ᵀ[a, b, c] := 𝐱[b, a, c]
21+
ein_mul(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
22+
1823
function SpectralConv1d(
1924
ch::Pair{<:Integer, <:Integer},
2025
modes::Integer,
@@ -27,25 +32,27 @@ function SpectralConv1d(
2732
weights = scale * init(out_chs, in_chs, modes)
2833

2934
return Chain(
35+
t,
3036
x -> Zygote.hook(real, x),
31-
SpectralConv1d(weights, in_chs, out_chs, modes, σ)
37+
SpectralConv1d(weights, in_chs, out_chs, modes, σ),
38+
t
3239
)
3340
end
3441

3542
Flux.@functor SpectralConv1d
3643

3744
function (m::SpectralConv1d)(𝐱::AbstractArray)
38-
𝐱_fft = fft(𝐱, 2) # [in_chs, x, batch]
39-
𝐱_selected = 𝐱_fft[:, 1:m.modes, :] # [in_chs, modes, batch]
45+
𝐱_fft = fft(𝐱, 1) # [x, in_chs, batch]
46+
𝐱_selected = 𝐱_fft[1:m.modes, :, :] # [modes, in_chs, batch]
4047

41-
# [out_chs, modes, batch] <- [in_chs, modes, batch] [out_chs, in_chs, modes]
42-
@tullio 𝐱_weighted[o, m, b] := 𝐱_selected[i, m, b] * m.weight[o, i, m]
48+
# [modes, out_chs, batch] <- [modes, in_chs, batch] [out_chs, in_chs, modes]
49+
𝐱_weighted = ein_mul(𝐱_selected, m.weight)
4350

44-
s = size(𝐱_weighted)
45-
d = size(𝐱, 2) - m.modes
46-
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, s[1], d, s[3:end]...), dims=2)
51+
s = size(𝐱_weighted)[2:end]
52+
d = size(𝐱, 1) - m.modes
53+
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, d, s...), dims=1)
4754

48-
𝐱_out = ifft(𝐱_padded, 2)
55+
𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch]
4956

5057
return m.σ.(real(𝐱_out))
5158
end

0 commit comments

Comments (0)