This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 15c6b06 (1 parent: ec9051d)

complete SparseKernel{N}

4 files changed (+79, -51 lines)

src/Transform/wavelet_transform.jl (25 additions, 14 deletions)
@@ -1,33 +1,44 @@
-struct SparseKernel{T,S}
-    k::Int
-    conv_blk::S
-    out_weight::T
+export
+    SparseKernel,
+    SparseKernel1D,
+    SparseKernel2D,
+    SparseKernel3D
+
+
+struct SparseKernel{N,T,S}
+    conv_blk::T
+    out_weight::S
+end
+
+function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S}
+    input_dim, emb_dim = ch
+    conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
+    W_out = Dense(emb_dim, input_dim; init=init)
+    return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
 end
 
-function SparseKernel1d(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
+function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
     input_dim = c*k
     emb_dim = 128
-    conv = Conv((3,), input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
-    W_out = Dense(emb_dim, input_dim; init=init)
-    return SparseKernel(k, conv, W_out)
+    return SparseKernel((3, ), input_dim=>emb_dim; init=init)
 end
 
-function SparseKernel2d(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
+function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
     input_dim = c*k^2
     emb_dim = α*k^2
-    conv = Conv((3, 3), input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
-    W_out = Dense(emb_dim, input_dim; init=init)
-    return SparseKernel(k, conv, W_out)
+    return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
 end
 
-function SparseKernel3d(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
+function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
     input_dim = c*k^2
     emb_dim = α*k^2
     conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
     W_out = Dense(emb_dim, input_dim; init=init)
-    return SparseKernel(k, conv, W_out)
+    return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
 end
 
+Flux.@functor SparseKernel
+
 function (l::SparseKernel)(X::AbstractArray)
     bch_sz, _, dims_r... = reverse(size(X))
     dims = reverse(dims_r)
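
For orientation (not part of the commit): after this refactor, SparseKernel{N} stores only the convolution block and the output weight, and the 1D/2D constructors delegate to the generic SparseKernel(filter, ch) method. A minimal usage sketch, assuming Flux's (spatial..., channels, batch) array layout and the shapes used in the tests below:

    using Flux

    # k = 3, α = 4, c = 3: the 2D layer expects c*k^2 = 27 input channels
    l = SparseKernel2D(3, 4, 3)

    X = rand(Float32, 5, 7, 27, 32)   # (Nx, Ny, channels, batch)
    Y = l(X)                          # size(Y) == size(X), per the tests below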

test/Transform/Transform.jl (1 addition, 0 deletions)

@@ -1,4 +1,5 @@
 @testset "Transform" begin
     include("fourier_transform.jl")
     include("chebyshev_transform.jl")
+    include("wavelet_transform.jl")
 end
test/Transform/wavelet_transform.jl (53 additions, 0 deletions)

@@ -0,0 +1,53 @@
+@testset "SparseKernel" begin
+    T = Float32
+    k = 3
+    batch_size = 32
+
+    @testset "1D SparseKernel" begin
+        α = 4
+        c = 1
+        in_chs = 20
+        X = rand(T, in_chs, c*k, batch_size)
+
+        l1 = SparseKernel1D(k, α, c)
+        Y = l1(X)
+        @test l1 isa SparseKernel{1}
+        @test size(Y) == size(X)
+
+        gs = gradient(()->sum(l1(X)), Flux.params(l1))
+        @test length(gs.grads) == 4
+    end
+
+    @testset "2D SparseKernel" begin
+        α = 4
+        c = 3
+        Nx = 5
+        Ny = 7
+        X = rand(T, Nx, Ny, c*k^2, batch_size)
+
+        l2 = SparseKernel2D(k, α, c)
+        Y = l2(X)
+        @test l2 isa SparseKernel{2}
+        @test size(Y) == size(X)
+
+        gs = gradient(()->sum(l2(X)), Flux.params(l2))
+        @test length(gs.grads) == 4
+    end
+
+    @testset "3D SparseKernel" begin
+        α = 4
+        c = 3
+        Nx = 5
+        Ny = 7
+        Nz = 13
+        X = rand(T, Nx, Ny, Nz, α*k^2, batch_size)
+
+        l3 = SparseKernel3D(k, α, c)
+        Y = l3(X)
+        @test l3 isa SparseKernel{3}
+        @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size)
+
+        gs = gradient(()->sum(l3(X)), Flux.params(l3))
+        @test length(gs.grads) == 4
+    end
+end
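
Worth noting from these tests (an observation, not a change made by this commit): unlike the 1D and 2D variants, SparseKernel3D is not channel-preserving. Its Conv maps emb_dim=>emb_dim, so the layer consumes α*k^2 input channels while its Dense output projects back to c*k^2:

    l3 = SparseKernel3D(3, 4, 3)             # k = 3, α = 4, c = 3
    X  = rand(Float32, 5, 7, 13, 4*3^2, 32)  # input channels: α*k^2 = 36
    Y  = l3(X)                               # output channels: c*k^2 = 27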

test/wavelet.jl (0 additions, 37 deletions)

This file was deleted.
