Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit c31cd64

Browse files
authored
Merge pull request #56 from SciML/loss
Implement l2 loss
2 parents 6014f52 + d8e84c6 commit c31cd64

File tree

5 files changed

+27
-7
lines changed

5 files changed

+27
-7
lines changed

example/FlowOverCircle/src/FlowOverCircle.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function train()
3131
Dense(64, 1),
3232
) |> device
3333

34-
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
34+
loss(𝐱, 𝐲) = l₂loss(m(𝐱), 𝐲)
3535

3636
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
3737

src/NeuralOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module NeuralOperators
1414

1515
include("Transform/Transform.jl")
1616
include("operator_kernel.jl")
17+
include("loss.jl")
1718
include("model.jl")
1819
include("DeepONet.jl")
1920
include("subnets.jl")

src/loss.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
export l₂loss
2+
3+
"""
    l₂loss(𝐲̂, 𝐲; agg=mean, grid_normalize=true)

Compute the l₂ loss between a prediction `𝐲̂` and a target `𝐲`.

Squared errors are summed over the grid dimensions — every dimension except
the first (channel) and the last (batch) — then the square root is taken per
sample and the results are aggregated with `agg` (default `mean`). When
`grid_normalize` is `true`, the aggregated value is additionally divided by
`prod(2:(ndims(𝐲)-1))`.

# Arguments
- `𝐲̂`: predicted array, same shape as `𝐲`.
- `𝐲`: target array of shape `(channel, grid..., batch)`.

# Keywords
- `agg`: aggregation applied across samples (default `mean`).
- `grid_normalize`: whether to divide by `prod(2:(ndims(𝐲)-1))`.
"""
function l₂loss(𝐲̂, 𝐲; agg=mean, grid_normalize=true)
    # Grid axes sit between the channel axis (1) and the batch axis (end).
    grid_dims = 2:(ndims(𝐲) - 1)
    err = 𝐲̂ - 𝐲
    per_sample = sqrt.(sum(abs2, err; dims=grid_dims))
    loss = agg(per_sample)
    # NOTE(review): `prod(grid_dims)` multiplies the dimension *indices*
    # (e.g. 2*3 = 6 for a 4-D input), not the number of grid points —
    # confirm this is the intended normalization constant.
    return grid_normalize ? loss / prod(grid_dims) : loss
end

test/loss.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Checks that l₂loss matches a hand-computed reference: per-sample l₂ norm
# over the grid dimensions, averaged, then divided by prod(feature_dims)
# (the default grid_normalize=true path).
@testset "loss" begin
    𝐲 = rand(1, 3, 3, 5)
    𝐲̂ = rand(1, 3, 3, 5)

    # For a (channel, grid..., batch) array of ndims 4, the grid axes are 2:3.
    feature_dims = 2:3
    loss = mean(.√(sum(abs2, 𝐲̂-𝐲, dims=feature_dims)))

    # `≈` (isapprox) rather than `==`: floating-point comparison.
    @test l₂loss(𝐲̂, 𝐲) ≈ loss/prod(feature_dims)
end

test/runtests.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,21 @@ using Flux
44
using GeometricFlux
55
using Graphs
66
using Zygote
7+
using Statistics
78
using Test
89

910
CUDA.allowscalar(false)
1011

1112
cuda_tests = [
12-
"cuda",
13+
"cuda.jl",
1314
]
1415

1516
tests = [
16-
"Transform/Transform",
17-
"operator_kernel",
18-
"model",
19-
"deeponet",
17+
"Transform/Transform.jl",
18+
"operator_kernel.jl",
19+
"loss.jl",
20+
"model.jl",
21+
"deeponet.jl",
2022
]
2123

2224
if CUDA.functional()
@@ -27,7 +29,7 @@ end
2729

2830
@testset "NeuralOperators.jl" begin
2931
for t in tests
30-
include("$(t).jl")
32+
include(t)
3133
end
3234
end
3335

0 commit comments

Comments
 (0)