This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit f5a9017

Merge pull request #60 from SciML/comp: Update comp
2 parents (c31cd64 + 5c1168e)

4 files changed: +15 / -15 lines


Project.toml: 2 additions, 2 deletions

@@ -20,8 +20,8 @@ CUDA = "3.8"
 CUDAKernels = "0.3, 0.4"
 ChainRulesCore = "1.13"
 FFTW = "1.4"
-Flux = "0.12"
-GeometricFlux = "0.10"
+Flux = "0.13"
+GeometricFlux = "0.11"
 KernelAbstractions = "0.7, 0.8"
 Tullio = "0.3"
 Zygote = "0.6"
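The bumped [compat] bounds mean the package now resolves against Flux 0.13 and GeometricFlux 0.11. As a quick sanity check (illustrative only, not part of this commit), Pkg can report which versions the active environment actually resolved:

# Illustrative check, not part of this commit: print the resolved
# versions of the two bumped dependencies in the active environment.
using Pkg

for dep in values(Pkg.dependencies())
    if dep.name in ("Flux", "GeometricFlux")
        println(dep.name, " => ", dep.version)
    end
end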

src/model.jl: 2 additions, 2 deletions

@@ -27,7 +27,7 @@ function FourierNeuralOperator(;
         OperatorKernel(ch[5]=>ch[6], modes, Transform),
         Dense(ch[6], ch[7], σ),
         Dense(ch[7], ch[8]),
-        flatten
+        Flux.flatten
     )
 end
 
@@ -48,7 +48,7 @@ function MarkovNeuralOperator(;
     σ=gelu
 )
     Transform = FourierTransform
-
+
     return Chain(
         Dense(ch[1], ch[2]),
         OperatorKernel(ch[2]=>ch[3], modes, Transform, σ),
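The only source change here qualifies flatten as Flux.flatten, so the model no longer depends on the bare name being exported (which can clash with other packages that also export a flatten). A minimal sketch of what that final layer does, assuming Flux 0.13 (the array shapes below are stand-ins, not from the repo):

# Minimal sketch: Flux.flatten collapses all leading dimensions into one,
# keeping the trailing batch dimension, so a gridded operator output
# becomes a (features, batch) matrix.
using Flux

x = rand(Float32, 16, 16, 4, 5)    # width x height x channels x batch
y = Flux.flatten(x)
@assert size(y) == (16 * 16 * 4, 5)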

test/model.jl: 2 additions, 2 deletions

@@ -6,7 +6,7 @@
 
     loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
    data = [(𝐱, 𝐲)]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
 
 @testset "MarkovNeuralOperator" begin

@@ -17,5 +17,5 @@ end
 
     loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
     data = [(𝐱, 𝐲)]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
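Both test sets switch from the bare params to the fully qualified Flux.params, since the unqualified name is not reliably in scope across Flux versions. The call itself is Flux's implicit-parameter training style; a self-contained version of the pattern, with a stand-in Dense chain and random data in place of the repo's operators (assuming Flux 0.13, where ADAM is still an accepted optimiser name):

# Self-contained sketch of the tests' training pattern; the model and
# data are stand-ins, not the repo's neural operators.
using Flux

m = Chain(Dense(4 => 8, gelu), Dense(8 => 1))
𝐱, 𝐲 = rand(Float32, 4, 16), rand(Float32, 1, 16)
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
Flux.train!(loss, Flux.params(m), [(𝐱, 𝐲)], Flux.ADAM())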

test/operator_kernel.jl: 9 additions, 9 deletions

@@ -14,7 +14,7 @@
 
     loss(x, y) = Flux.mse(m(x), y)
     data = [(𝐱, rand(Float32, 128, 1024, 5))]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
 
 @testset "permuted 1D OperatorConv" begin

@@ -34,7 +34,7 @@ end
 
     loss(x, y) = Flux.mse(m(x), y)
     data = [(𝐱, rand(Float32, 1024, 128, 5))]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
 
 @testset "1D OperatorKernel" begin

@@ -52,7 +52,7 @@ end
 
     loss(x, y) = Flux.mse(m(x), y)
     data = [(𝐱, rand(Float32, 128, 1024, 5))]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
 
 @testset "permuted 1D OperatorKernel" begin

@@ -71,7 +71,7 @@ end
 
     loss(x, y) = Flux.mse(m(x), y)
     data = [(𝐱, rand(Float32, 1024, 128, 5))]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
 
 @testset "2D OperatorConv" begin

@@ -89,7 +89,7 @@ end
 
     loss(x, y) = Flux.mse(m(x), y)
     data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
 
 @testset "permuted 2D OperatorConv" begin

@@ -108,7 +108,7 @@ end
 
     loss(x, y) = Flux.mse(m(x), y)
     data = [(𝐱, rand(Float32, 22, 22, 64, 5))]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
 
 @testset "2D OperatorKernel" begin

@@ -125,7 +125,7 @@ end
 
     loss(x, y) = Flux.mse(m(x), y)
     data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
 
 @testset "permuted 2D OperatorKernel" begin

@@ -143,7 +143,7 @@ end
 
     loss(x, y) = Flux.mse(m(x), y)
     data = [(𝐱, rand(Float32, 22, 22, 64, 5))]
-    Flux.train!(loss, params(m), data, Flux.ADAM())
+    Flux.train!(loss, Flux.params(m), data, Flux.ADAM())
 end
 
 @testset "SpectralConv" begin

@@ -165,7 +165,7 @@ end
     graph = grid([10, 10])
     𝐱 = rand(Float32, channel, N, batch_size)
     l = WithGraph(FeaturedGraph(graph), GraphKernel(κ, channel))
-    @test repr(l.layer) == "GraphKernel(Dense(64, 32, relu), channel=32)"
+    @test repr(l.layer) == "GraphKernel(Dense(64 => 32, relu), channel=32)"
     @test size(l(𝐱)) == (channel, N, batch_size)
 
     g = Zygote.gradient(() -> sum(l(𝐱)), Flux.params(l))