Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 37e5d48

Browse files
committed
draft for MWT_CZ1d
1 parent 15c6b06 commit 37e5d48

File tree

2 files changed

+83
-86
lines changed

2 files changed

+83
-86
lines changed

src/Transform/polynomials.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
function get_filter(base::Symbol, k)
2+
if base == :legendre
3+
return legendre_filter(k)
4+
elseif base == :chebyshev
5+
return chebyshev_filter(k)
6+
else
7+
throw(ArgumentError("base must be one of :legendre or :chebyshev."))
8+
end
9+
end
10+
111
function legendre_ϕ_ψ(k)
212
# TODO: row-major -> column major
313
ϕ_coefs = zeros(k, k)

src/Transform/wavelet_transform.jl

Lines changed: 73 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -51,93 +51,80 @@ function (l::SparseKernel)(X::AbstractArray)
5151
end
5252

5353

54-
# struct MWT_CZ1d
55-
56-
# end
57-
58-
# function MWT_CZ1d(k::Int=3, c::Int=1; init=Flux.glorot_uniform)
59-
60-
# end
61-
62-
# class MWT_CZ1d(nn.Module):
63-
# def __init__(self,
64-
# k = 3, alpha = 5,
65-
# L = 0, c = 1,
66-
# base = 'legendre',
67-
# initializer = None,
68-
# **kwargs):
69-
# super(MWT_CZ1d, self).__init__()
70-
71-
# self.k = k
72-
# self.L = L
73-
# H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
74-
# H0r = H0@PHI0
75-
# G0r = G0@PHI0
76-
# H1r = H1@PHI1
77-
# G1r = G1@PHI1
78-
79-
# H0r[np.abs(H0r)<1e-8]=0
80-
# H1r[np.abs(H1r)<1e-8]=0
81-
# G0r[np.abs(G0r)<1e-8]=0
82-
# G1r[np.abs(G1r)<1e-8]=0
83-
84-
# self.A = sparseKernelFT1d(k, alpha, c)
85-
# self.B = sparseKernelFT1d(k, alpha, c)
86-
# self.C = sparseKernelFT1d(k, alpha, c)
87-
88-
# self.T0 = nn.Linear(k, k)
89-
90-
# self.register_buffer('ec_s', torch.Tensor(
91-
# np.concatenate((H0.T, H1.T), axis=0)))
92-
# self.register_buffer('ec_d', torch.Tensor(
93-
# np.concatenate((G0.T, G1.T), axis=0)))
94-
95-
# self.register_buffer('rc_e', torch.Tensor(
96-
# np.concatenate((H0r, G0r), axis=0)))
97-
# self.register_buffer('rc_o', torch.Tensor(
98-
# np.concatenate((H1r, G1r), axis=0)))
99-
100-
101-
# def forward(self, x):
102-
103-
# B, N, c, ich = x.shape # (B, N, k)
104-
# ns = math.floor(np.log2(N))
105-
106-
# Ud = torch.jit.annotate(List[Tensor], [])
107-
# Us = torch.jit.annotate(List[Tensor], [])
108-
# # decompose
109-
# for i in range(ns-self.L):
110-
# d, x = self.wavelet_transform(x)
111-
# Ud += [self.A(d) + self.B(x)]
112-
# Us += [self.C(d)]
113-
# x = self.T0(x) # coarsest scale transform
114-
115-
# # reconstruct
116-
# for i in range(ns-1-self.L,-1,-1):
117-
# x = x + Us[i]
118-
# x = torch.cat((x, Ud[i]), -1)
119-
# x = self.evenOdd(x)
120-
# return x
121-
122-
123-
# def wavelet_transform(self, x):
124-
# xa = torch.cat([x[:, ::2, :, :],
125-
# x[:, 1::2, :, :],
126-
# ], -1)
127-
# d = torch.matmul(xa, self.ec_d)
128-
# s = torch.matmul(xa, self.ec_s)
129-
# return d, s
130-
131-
132-
# def evenOdd(self, x):
133-
134-
# B, N, c, ich = x.shape # (B, N, c, k)
135-
# assert ich == 2*self.k
136-
# x_e = torch.matmul(x, self.rc_e)
137-
# x_o = torch.matmul(x, self.rc_o)
138-
54+
struct MWT_CZ1d{T,S,R,Q,P}
55+
k::Int
56+
L::Int
57+
A::T
58+
B::S
59+
C::R
60+
T0::Q
61+
ec_s::P
62+
ec_d::P
63+
rc_e::P
64+
rc_o::P
65+
end
66+
67+
function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
68+
H0, H1, G0, G1, ϕ0, ϕ1 = get_filter(base, k)
69+
H0r = zero_out!(H0 * ϕ0)
70+
G0r = zero_out!(G0 * ϕ0)
71+
H1r = zero_out!(H1 * ϕ1)
72+
G1r = zero_out!(G1 * ϕ1)
73+
74+
dim = c*k
75+
A = SpectralConv(dim=>dim, (α,); init=init)
76+
B = SpectralConv(dim=>dim, (α,); init=init)
77+
C = SpectralConv(dim=>dim, (α,); init=init)
78+
T0 = Dense(k, k)
79+
80+
ec_s = vcat(H0', H1')
81+
ec_d = vcat(G0', G1')
82+
rc_e = vcat(H0r, G0r)
83+
rc_o = vcat(H1r, G1r)
84+
return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
85+
end
86+
87+
function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
88+
N = size(X, 3)
89+
Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
90+
d = NNlib.batched_mul(Xa, l.ec_d)
91+
s = NNlib.batched_mul(Xa, l.ec_s)
92+
return d, s
93+
end
94+
95+
function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
96+
bch_sz, N, dims_r... = reverse(size(X))
97+
dims = reverse(dims_r)
98+
@assert dims[1] == 2*l.k
99+
Xₑ = NNlib.batched_mul(X, l.rc_e)
100+
Xₒ = NNlib.batched_mul(X, l.rc_o)
139101
# x = torch.zeros(B, N*2, c, self.k,
140102
# device = x.device)
141103
# x[..., ::2, :, :] = x_e
142104
# x[..., 1::2, :, :] = x_o
143-
# return x
105+
return X
106+
end
107+
108+
function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
109+
bch_sz, N, dims_r... = reverse(size(X))
110+
ns = floor(log2(N))
111+
stop = ns - l.L
112+
113+
# decompose
114+
Ud = T[]
115+
Us = T[]
116+
for i in 1:stop
117+
d, X = wavelet_transform(l, X)
118+
push!(Ud, l.A(d)+l.B(d))
119+
push!(Us, l.C(d))
120+
end
121+
X = l.T0(X)
122+
123+
# reconstruct
124+
for i in stop:-1:1
125+
X += Us[i]
126+
X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
127+
X = even_odd(l, X)
128+
end
129+
return X
130+
end

0 commit comments

Comments
 (0)