@@ -51,93 +51,80 @@ function (l::SparseKernel)(X::AbstractArray)
51
51
end
52
52
53
53
54
- # struct MWT_CZ1d
55
-
56
- # end
57
-
58
- # function MWT_CZ1d(k::Int=3, c::Int=1; init=Flux.glorot_uniform)
59
-
60
- # end
61
-
62
- # class MWT_CZ1d(nn.Module):
63
- # def __init__(self,
64
- # k = 3, alpha = 5,
65
- # L = 0, c = 1,
66
- # base = 'legendre',
67
- # initializer = None,
68
- # **kwargs):
69
- # super(MWT_CZ1d, self).__init__()
70
-
71
- # self.k = k
72
- # self.L = L
73
- # H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
74
- # H0r = H0@PHI0
75
- # G0r = G0@PHI0
76
- # H1r = H1@PHI1
77
- # G1r = G1@PHI1
78
-
79
- # H0r[np.abs(H0r)<1e-8]=0
80
- # H1r[np.abs(H1r)<1e-8]=0
81
- # G0r[np.abs(G0r)<1e-8]=0
82
- # G1r[np.abs(G1r)<1e-8]=0
83
-
84
- # self.A = sparseKernelFT1d(k, alpha, c)
85
- # self.B = sparseKernelFT1d(k, alpha, c)
86
- # self.C = sparseKernelFT1d(k, alpha, c)
87
-
88
- # self.T0 = nn.Linear(k, k)
89
-
90
- # self.register_buffer('ec_s', torch.Tensor(
91
- # np.concatenate((H0.T, H1.T), axis=0)))
92
- # self.register_buffer('ec_d', torch.Tensor(
93
- # np.concatenate((G0.T, G1.T), axis=0)))
94
-
95
- # self.register_buffer('rc_e', torch.Tensor(
96
- # np.concatenate((H0r, G0r), axis=0)))
97
- # self.register_buffer('rc_o', torch.Tensor(
98
- # np.concatenate((H1r, G1r), axis=0)))
99
-
100
-
101
- # def forward(self, x):
102
-
103
- # B, N, c, ich = x.shape # (B, N, k)
104
- # ns = math.floor(np.log2(N))
105
-
106
- # Ud = torch.jit.annotate(List[Tensor], [])
107
- # Us = torch.jit.annotate(List[Tensor], [])
108
- # # decompose
109
- # for i in range(ns-self.L):
110
- # d, x = self.wavelet_transform(x)
111
- # Ud += [self.A(d) + self.B(x)]
112
- # Us += [self.C(d)]
113
- # x = self.T0(x) # coarsest scale transform
114
-
115
- # # reconstruct
116
- # for i in range(ns-1-self.L,-1,-1):
117
- # x = x + Us[i]
118
- # x = torch.cat((x, Ud[i]), -1)
119
- # x = self.evenOdd(x)
120
- # return x
121
-
122
-
123
- # def wavelet_transform(self, x):
124
- # xa = torch.cat([x[:, ::2, :, :],
125
- # x[:, 1::2, :, :],
126
- # ], -1)
127
- # d = torch.matmul(xa, self.ec_d)
128
- # s = torch.matmul(xa, self.ec_s)
129
- # return d, s
130
-
131
-
132
- # def evenOdd(self, x):
133
-
134
- # B, N, c, ich = x.shape # (B, N, c, k)
135
- # assert ich == 2*self.k
136
- # x_e = torch.matmul(x, self.rc_e)
137
- # x_o = torch.matmul(x, self.rc_o)
138
-
54
+ struct MWT_CZ1d{T,S,R,Q,P}
55
+ k:: Int
56
+ L:: Int
57
+ A:: T
58
+ B:: S
59
+ C:: R
60
+ T0:: Q
61
+ ec_s:: P
62
+ ec_d:: P
63
+ rc_e:: P
64
+ rc_o:: P
65
+ end
66
+
67
+ function MWT_CZ1d (k:: Int = 3 , α:: Int = 5 , L:: Int = 0 , c:: Int = 1 ; base:: Symbol = :legendre , init= Flux. glorot_uniform)
68
+ H0, H1, G0, G1, ϕ0, ϕ1 = get_filter (base, k)
69
+ H0r = zero_out! (H0 * ϕ0)
70
+ G0r = zero_out! (G0 * ϕ0)
71
+ H1r = zero_out! (H1 * ϕ1)
72
+ G1r = zero_out! (G1 * ϕ1)
73
+
74
+ dim = c* k
75
+ A = SpectralConv (dim=> dim, (α,); init= init)
76
+ B = SpectralConv (dim=> dim, (α,); init= init)
77
+ C = SpectralConv (dim=> dim, (α,); init= init)
78
+ T0 = Dense (k, k)
79
+
80
+ ec_s = vcat (H0' , H1' )
81
+ ec_d = vcat (G0' , G1' )
82
+ rc_e = vcat (H0r, G0r)
83
+ rc_o = vcat (H1r, G1r)
84
+ return MWT_CZ1d (k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
85
+ end
86
+
87
+ function wavelet_transform (l:: MWT_CZ1d , X:: AbstractArray{T,4} ) where {T}
88
+ N = size (X, 3 )
89
+ Xa = vcat (view (X, :, :, 1 : 2 : N, :), view (X, :, :, 2 : 2 : N, :))
90
+ d = NNlib. batched_mul (Xa, l. ec_d)
91
+ s = NNlib. batched_mul (Xa, l. ec_s)
92
+ return d, s
93
+ end
94
+
95
+ function even_odd (l:: MWT_CZ1d , X:: AbstractArray{T,4} ) where {T}
96
+ bch_sz, N, dims_r... = reverse (size (X))
97
+ dims = reverse (dims_r)
98
+ @assert dims[1 ] == 2 * l. k
99
+ Xₑ = NNlib. batched_mul (X, l. rc_e)
100
+ Xₒ = NNlib. batched_mul (X, l. rc_o)
139
101
# x = torch.zeros(B, N*2, c, self.k,
140
102
# device = x.device)
141
103
# x[..., ::2, :, :] = x_e
142
104
# x[..., 1::2, :, :] = x_o
143
- # return x
105
+ return X
106
+ end
107
+
108
+ function (l:: MWT_CZ1d )(X:: T ) where {T<: AbstractArray }
109
+ bch_sz, N, dims_r... = reverse (size (X))
110
+ ns = floor (log2 (N))
111
+ stop = ns - l. L
112
+
113
+ # decompose
114
+ Ud = T[]
115
+ Us = T[]
116
+ for i in 1 : stop
117
+ d, X = wavelet_transform (l, X)
118
+ push! (Ud, l. A (d)+ l. B (d))
119
+ push! (Us, l. C (d))
120
+ end
121
+ X = l. T0 (X)
122
+
123
+ # reconstruct
124
+ for i in stop: - 1 : 1
125
+ X += Us[i]
126
+ X = vcat (X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
127
+ X = even_odd (l, X)
128
+ end
129
+ return X
130
+ end
0 commit comments