Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit d49f1ac

Browse files
committed
add Chebyshev transform
1 parent ae5c22e commit d49f1ac

File tree

5 files changed

+48
-6
lines changed

5 files changed

+48
-6
lines changed

src/Transform/Transform.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@ export
1717
abstract type AbstractTransform end
1818

1919
include("fourier_transform.jl")
20+
include("chebyshev_transform.jl")
21+
22+
const truncate_modes = low_pass
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
export ChebyshevTransform
2+
3+
struct ChebyshevTransform{N, S}<:AbstractTransform
4+
modes::NTuple{N, S} # N == ndims(x)
5+
end
6+
7+
Base.ndims(::ChebyshevTransform{N}) where {N} = N
8+
9+
function transform(t::ChebyshevTransform{N}, 𝐱::AbstractArray) where {N}
10+
return FFTW.r2r(𝐱, FFTW.REDFT00, 1:N) # [size(x)..., in_chs, batch]
11+
end
12+
13+
function low_pass(t::ChebyshevTransform, 𝐱̂::AbstractArray)
14+
return view(𝐱̂, map(d->1:d, t.modes)..., :, :) # [ft.modes..., in_chs, batch]
15+
end
16+
17+
function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N}
18+
return FFTW.r2r(
19+
𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1))),
20+
FFTW.REDFT00,
21+
1:N,
22+
) # [size(x)..., in_chs, batch]
23+
end

test/Transform/Transform.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
@testset "Transform" begin include("fourier_transform.jl") end
1+
@testset "Transform" begin
2+
include("fourier_transform.jl")
3+
include("chebyshev_transform.jl")
4+
end
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@testset "Chebyshev transform" begin
2+
ch = 6
3+
batch = 7
4+
𝐱 = rand(30, 40, 50, ch, batch)
5+
6+
t = ChebyshevTransform((3, 4, 5))
7+
8+
@test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch)
9+
@test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
10+
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch)
11+
end
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
@testset "fourier transform" begin
2-
𝐱 = rand(30, 40, 50, 6, 7) # where ch == 6 and batch == 7
1+
@testset "Fourier transform" begin
2+
ch = 6
3+
batch = 7
4+
𝐱 = rand(30, 40, 50, ch, batch)
35

46
ft = FourierTransform((3, 4, 5))
57

6-
@test size(transform(ft, 𝐱)) == (30, 40, 50, 6, 7)
7-
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, 6, 7)
8-
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, 6, 7)
8+
@test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch)
9+
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
10+
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, ch, batch)
911
end

0 commit comments

Comments
 (0)