This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 0bf7fda

Merge pull request #4 from ayushinav/main
Minor fixes
2 parents: f42964c + 4daced6

File tree

5 files changed: +6, -6 lines

src/fno.jl

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ kernels, and two `Dense` layers to project data back to the scalar field of inte
 - `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension
   of data.
 - `σ`: Activation function for all layers in the model.
-- `permuted`: Whether the dim is permuted. If `permuted = Val(true)`, the layer accepts
+- `permuted`: Whether the dim is permuted. If `permuted = Val(false)`, the layer accepts
   data in the order of `(ch, x_1, ... , x_d , batch)`. Otherwise the order is
   `(x_1, ... , x_d, ch, batch)`.
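The same docstring fix is applied in src/layers.jl below. As a quick illustration of the corrected wording, a minimal sketch of the two layouts for d = 1 (the single channel and batch size 5 are taken loosely from the test file further down; the array contents are made up):

x_true  = rand(Float32, 1024, 1, 5)  # permuted = Val(true):  (x_1, ch, batch)
x_false = rand(Float32, 1, 1024, 5)  # permuted = Val(false): (ch, x_1, batch)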

src/layers.jl

Lines changed: 3 additions & 3 deletions

@@ -15,7 +15,7 @@
 ## Keyword Arguments

 - `init_weight`: Initial function to initialize parameters.
-- `permuted`: Whether the dim is permuted. If `permuted = Val(true)`, the layer accepts
+- `permuted`: Whether the dim is permuted. If `permuted = Val(false)`, the layer accepts
   data in the order of `(ch, x_1, ... , x_d , batch)`. Otherwise the order is
   `(x_1, ... , x_d, ch, batch)`.
 - `T`: Datatype of parameters.

@@ -86,7 +86,7 @@ SpectralConv(args...; kwargs...) = OperatorConv(args..., FourierTransform; kwarg
 - `ch`: A `Pair` of input and output channel size `ch_in => ch_out`, e.g. `64 => 64`.
 - `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension of
   data.
-- `::Type{TR}`: The traform to operate the transformation.
+- `::Type{TR}`: The transform to operate the transformation.

 ## Keyword Arguments

@@ -237,7 +237,7 @@ end
 end

 @inline function __apply_pattern_batched_mul(x, y)
-    # Use permutedims to guarantee contiguous memory
+    # Use permutedims to guarantee contiguous memory
    x_ = permutedims(x, (2, 3, 1)) # i x b x m
    res = batched_mul(y, x_) # o x b x m
    return permutedims(res, (3, 1, 2)) # m x o x b

(The change in the last hunk is whitespace-only; the comment text itself is unchanged.)
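For context on that last hunk, a small self-contained sketch of the shape bookkeeping in `__apply_pattern_batched_mul`, with made-up dimension sizes; it assumes `batched_mul` comes from NNlib (which batches over the last array dimension), consistent with the comments in the diff:

using NNlib: batched_mul

m, i, o, b = 8, 4, 6, 3            # modes, input channels, output channels, batch
x = rand(Float32, m, i, b)         # input laid out as (m, i, b)
y = rand(Float32, o, i, m)         # one o × i weight matrix per mode

x_ = permutedims(x, (2, 3, 1))     # (i, b, m): a contiguous i × b slice per mode
res = batched_mul(y, x_)           # (o, b, m): y[:, :, k] * x_[:, :, k] for each mode k
out = permutedims(res, (3, 1, 2))  # back to (m, o, b)
@assert size(out) == (m, o, b)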

test/fno.jl

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ include("test_utils.jl")
         y_size=(1024, 1, 5), permuted=Val(true))]

 @testset "$(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups
-    fno = FourierNeuralOperator(; rng, setup.chs, setup.modes, setup.permuted)
+    fno = FourierNeuralOperator(rng; setup.chs, setup.modes, setup.permuted)

     x = rand(rng, Float32, setup.x_size...)
     y = rand(rng, Float32, setup.y_size...)
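This test change tracks a constructor signature change: the RNG is now the first positional argument rather than a keyword. A minimal usage sketch (the `chs` and `modes` values are hypothetical, not taken from this diff, and the package exporting `FourierNeuralOperator` is assumed to be loaded):

using Random

rng = Random.Xoshiro(0)
chs = (2, 64, 128, 1)  # hypothetical channel sizes
modes = (16,)          # hypothetical retained modes

# before this commit: fno = FourierNeuralOperator(; rng, chs, modes, permuted = Val(true))
fno = FourierNeuralOperator(rng; chs, modes, permuted = Val(true))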
(A fifth file was renamed without changes.)

test/test_utils.jl

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ default_loss_function(model, ps, x, y) = mean(abs2, y .- model(x, ps))
 train!(args...; kwargs...) = train!(default_loss_function, args...; kwargs...)

 function train!(loss, model, ps, st, data; epochs = 10)
-    m = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    m = Lux.StatefulLuxLayer(model, ps, st)

     l1 = loss(m, ps, first(data)...)
     st_opt = Optimisers.setup(Adam(0.01f0), ps)
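The helper now uses the top-level `Lux.StatefulLuxLayer`, reflecting its promotion out of the `Lux.Experimental` namespace. A minimal sketch of the wrapper as the updated `train!` uses it (the model and sizes here are made up):

using Lux, Random

rng = Random.Xoshiro(0)
model = Dense(2 => 2)
ps, st = Lux.setup(rng, model)

m = Lux.StatefulLuxLayer(model, ps, st)  # bundles the model with its state
y = m(rand(rng, Float32, 2, 4), ps)      # callable as m(x, ps), as in default_loss_function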
