Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit c254984

Browse files
committed
Rewrite eltype in transform
1 parent 0c7ac83 commit c254984

File tree

9 files changed

+46
-43
lines changed

9 files changed

+46
-43
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
2323
LuxNeuralOperatorsAMDGPUExt = "AMDGPU"
2424

2525
[compat]
26-
AMDGPU = "0.9.5"
26+
AMDGPU = "0.8.4, 0.9"
2727
Aqua = "0.8.7"
2828
ArgCheck = "2.3.0"
2929
ChainRulesCore = "1.24.0"

ext/LuxNeuralOperatorsAMDGPUExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ using LuxNeuralOperators: LuxNeuralOperators
1111
return stack(*, eachslice(x; dims=3), eachslice(y; dims=3))
1212
end
1313

14-
end
14+
end

src/LuxNeuralOperators.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ const CRC = ChainRulesCore
1818

1919
@reexport using Lux
2020

21-
const True = Val(true)
22-
const False = Val(false)
23-
2421
include("utils.jl")
2522
include("transform.jl")
2623

src/fno.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ FourierNeuralOperator(
4949
"""
5050
function FourierNeuralOperator(
5151
σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
52-
permuted::Val{perm}=False, kwargs...) where {C, M, perm}
52+
permuted::Val{perm}=Val(false), kwargs...) where {C, M, perm}
5353
@argcheck length(chs) 5
5454

5555
map₁ = chs[1] => chs[2]

src/functional.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ end
1111
x_size = size(x_tr)
1212
x_flat = reshape(x_tr, :, x_size[N - 1], x_size[N])
1313

14-
x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m
15-
x_weighted = permutedims(weights x_flat_t, (3, 1, 2)) # m x o x b
14+
x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m
15+
x_weighted = permutedims(__batched_mul(weights, x_flat_t), (3, 1, 2)) # m x o x b
1616

1717
return reshape(x_weighted, x_size[1:(N - 2)]..., size(x_weighted)[2:3]...)
1818
end

src/layers.jl

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,33 @@
11
"""
2-
OperatorConv(ch::Pair{<:Integer, <:Integer}, modes::NTuple{N, <:Integer}, ::Type{TR};
3-
init_weight = glorot_uniform, T::Type{TP} = ComplexF32,
4-
permuted::Val{P} = Val(false)) where {N, TR <: AbstractTransform, TP, P}
2+
OperatorConv(ch::Pair{<:Integer, <:Integer}, modes::NTuple{N, <:Integer},
3+
::Type{TR}; init_weight=glorot_uniform,
4+
permuted::Val{perm}=Val(false)) where {N, TR <: AbstractTransform, perm}
55
66
## Arguments
77
88
- `ch`: A `Pair` of input and output channel size `ch_in => ch_out`, e.g. `64 => 64`.
99
- `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension of
1010
data.
11-
- `::Type{TR}`: The traform to operate the transformation.
11+
- `::Type{TR}`: The transform to operate the transformation.
1212
1313
## Keyword Arguments
1414
1515
- `init_weight`: Initial function to initialize parameters.
1616
- `permuted`: Whether the dim is permuted. If `permuted = Val(false)`, the layer accepts
1717
data in the order of `(ch, x_1, ... , x_d, batch)`. Otherwise the order is
1818
`(x_1, ... , x_d, ch, batch)`.
19-
- `T`: Datatype of parameters.
2019
2120
## Example
2221
2322
```jldoctest
24-
julia> OperatorConv(2 => 5, (16,), FourierTransform)
23+
julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32})
2524
OperatorConv{FourierTransform}(2 => 5, (16,); permuted = false)() # 160 parameters
2625
27-
julia> OperatorConv(2 => 5, (16,), FourierTransform; permuted=Val(true))
26+
julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}; permuted=Val(true))
2827
OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)() # 160 parameters
2928
```
3029
"""
31-
@concrete struct OperatorConv{elType, perm, T <: AbstractTransform} <: AbstractExplicitLayer
30+
@concrete struct OperatorConv{perm, T <: AbstractTransform} <: AbstractExplicitLayer
3231
in_chs::Int
3332
out_chs::Int
3433
prod_modes::Int
@@ -39,30 +38,30 @@ OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)() # 160 paramete
3938
name::String
4039
end
4140

42-
function LuxCore.initialparameters(
43-
rng::AbstractRNG, layer::OperatorConv{elType}) where {elType}
41+
function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv)
4442
in_chs, out_chs = layer.in_chs, layer.out_chs
45-
scale = real(one(elType)) / (in_chs * out_chs)
46-
weights = scale * layer.init_weight(rng, elType, out_chs, in_chs, layer.prod_modes)
47-
return (; weights,)
43+
scale = real(one(eltype(layer.tform))) / (in_chs * out_chs)
44+
return (;
45+
weights=scale * layer.init_weight(
46+
rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes))
4847
end
4948

5049
@inline function LuxCore.parameterlength(layer::OperatorConv)
5150
return layer.prod_modes * layer.in_chs * layer.out_chs
5251
end
5352

5453
function OperatorConv(ch::Pair{<:Integer, <:Integer}, modes::NTuple{N, <:Integer},
55-
::Type{TR}; init_weight=glorot_uniform, T::Type{TP}=ComplexF32,
56-
permuted::Val{perm}=Val(false)) where {N, TR <: AbstractTransform, TP, perm}
54+
::Type{TR}; init_weight=glorot_uniform,
55+
permuted::Val{perm}=Val(false)) where {N, TR <: AbstractTransform{<:Number}, perm}
5756
name = "OperatorConv{$(string(nameof(TR)))}($(ch[1]) => $(ch[2]), $modes; permuted = $perm)"
58-
return OperatorConv{TP, perm}(ch..., prod(modes), TR(modes), init_weight, name)
57+
return OperatorConv{perm}(ch..., prod(modes), TR(modes), init_weight, name)
5958
end
6059

61-
function (conv::OperatorConv{T, true})(x::AbstractArray{<:Real, M}, ps, st) where {T, M}
60+
function (conv::OperatorConv{true})(x::AbstractArray, ps, st)
6261
return operator_conv(x, conv.tform, ps.weights), st
6362
end
6463

65-
function (conv::OperatorConv{T, false})(x::AbstractArray{<:Real, M}, ps, st) where {T, M}
64+
function (conv::OperatorConv{false})(x::AbstractArray, ps, st)
6665
N = ndims(conv.tform)
6766
xᵀ = permutedims(x, (ntuple(i -> i + 1, N)..., 1, N + 2))
6867
yᵀ = operator_conv(xᵀ, conv.tform, ps.weights)
@@ -73,7 +72,7 @@ end
7372
"""
7473
SpectralConv(args...; kwargs...)
7574
76-
Construct a `OperatorConv` with `FourierTransform` as the transform. See
75+
Construct a `OperatorConv` with `FourierTransform{ComplexF32}` as the transform. See
7776
[`OperatorConv`](@ref) for the individual arguments.
7877
7978
## Example
@@ -86,7 +85,8 @@ julia> SpectralConv(2 => 5, (16,); permuted=Val(true))
8685
OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)() # 160 parameters
8786
```
8887
"""
89-
SpectralConv(args...; kwargs...) = OperatorConv(args..., FourierTransform; kwargs...)
88+
SpectralConv(args...; kwargs...) = OperatorConv(
89+
args..., FourierTransform{ComplexF32}; kwargs...)
9090

9191
"""
9292
OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR},
@@ -106,14 +106,13 @@ SpectralConv(args...; kwargs...) = OperatorConv(args..., FourierTransform; kwarg
106106
- `permuted`: Whether the dim is permuted. If `permuted = Val(true)`, the layer accepts
107107
data in the order of `(ch, x_1, ... , x_d , batch)`. Otherwise the order is
108108
`(x_1, ... , x_d, ch, batch)`.
109-
- `T`: Datatype of parameters.
110109
111110
All the keyword arguments are passed to the [`OperatorConv`](@ref) constructor.
112111
113112
## Example
114113
115114
```jldoctest
116-
julia> OperatorKernel(2 => 5, (16,), FourierTransform)
115+
julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64})
117116
@compact(
118117
l₁ = Dense(2 => 5), # 15 parameters
119118
l₂ = OperatorConv{FourierTransform}(2 => 5, (16,); permuted = false)(), # 160 parameters
@@ -125,7 +124,7 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform)
125124
end # Total: 175 parameters,
126125
# plus 1 states.
127126
128-
julia> OperatorKernel(2 => 5, (16,), FourierTransform; permuted=Val(true))
127+
julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val(true))
129128
@compact(
130129
l₁ = Conv((1,), 2 => 5), # 15 parameters
131130
l₂ = OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)(), # 160 parameters
@@ -140,7 +139,7 @@ end # Total: 175 parameters,
140139
"""
141140
function OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR},
142141
act::A=identity; allow_fast_activation::Bool=false, permuted::Val{perm}=Val(false),
143-
kwargs...) where {N, TR <: AbstractTransform, perm, A}
142+
kwargs...) where {N, TR <: AbstractTransform{<:Number}, perm, A}
144143
act = allow_fast_activation ? NNlib.fast_act(act) : act
145144
l₁ = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
146145
l₂ = OperatorConv(ch, modes, transform; permuted, kwargs...)
@@ -155,7 +154,7 @@ end
155154
"""
156155
SpectralKernel(args...; kwargs...)
157156
158-
Construct a `OperatorKernel` with `FourierTransform` as the transform. See
157+
Construct a `OperatorKernel` with `FourierTransform{ComplexF32}` as the transform. See
159158
[`OperatorKernel`](@ref) for the individual arguments.
160159
161160
## Example
@@ -188,5 +187,5 @@ end # Total: 175 parameters,
188187
"""
189188
function SpectralKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims{N},
190189
act::A=identity; kwargs...) where {N, A}
191-
return OperatorKernel(ch, modes, FourierTransform, act; kwargs...)
190+
return OperatorKernel(ch, modes, FourierTransform{ComplexF32}, act; kwargs...)
192191
end

src/transform.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
- `inverse(<:AbstractTransform, x_transformed::AbstractArray)`: Apply the inverse
1111
transform to `x_transformed`
1212
"""
13-
abstract type AbstractTransform end
13+
abstract type AbstractTransform{T} end
14+
15+
@inline Base.eltype(::Type{<:AbstractTransform{T}}) where {T} = T
1416

1517
# Fourier Transform
16-
@concrete struct FourierTransform <: AbstractTransform
18+
@concrete struct FourierTransform{T} <: AbstractTransform{T}
1719
modes
1820
end
1921

20-
Base.ndims(T::FourierTransform) = length(T.modes)
21-
Base.eltype(::Type{FourierTransform}) = ComplexF32
22+
@inline Base.ndims(T::FourierTransform) = length(T.modes)
2223

2324
@inline transform(ft::FourierTransform, x::AbstractArray) = rfft(x, 1:ndims(ft))
2425

@@ -28,7 +29,7 @@ end
2829

2930
@inline truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft)
3031

31-
function inverse(
32+
@inline function inverse(
3233
ft::FourierTransform, x_fft::AbstractArray{T, N}, M::NTuple{N, Int64}) where {T, N}
3334
return real(irfft(x_fft, first(M), 1:ndims(ft)))
3435
end

test/fno_tests.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,11 @@
2222
@test size(first(fno(x, ps, st))) == setup.y_size
2323

2424
data = [(x, y)]
25-
l2, l1 = train!(fno, ps, st, data; epochs=10)
26-
@test l2 < l1
25+
broken = mode == "AMDGPU"
26+
@test begin
27+
l2, l1 = train!(fno, ps, st, data; epochs=10)
28+
l2 < l1
29+
end broken=broken
2730
end
2831
end
2932
end

test/layers_tests.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@
3030
@jet m(x, ps, st)
3131

3232
data = [(x, aType(rand(rng, Float32, setup.y_size...)))]
33-
l2, l1 = train!(m, ps, st, data; epochs=10)
34-
@test l2 < l1
33+
broken = mode == "AMDGPU"
34+
@test begin
35+
l2, l1 = train!(m, ps, st, data; epochs=10)
36+
l2 < l1
37+
end broken=broken
3538
end
3639
end
3740
end

0 commit comments

Comments
 (0)