1- function M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size, z, fft_timer, mapping_timer; rdft:: Bool = false )
2- N = prod(_size)
1+ struct FFTOperator{R,C,OP,N,IM}
2+ buffer_real:: R
3+ buffer_complex1:: C
4+ buffer_complex2:: C
5+ op:: OP
6+ DFTdim:: Int64
7+ DFTsize:: NTuple{N,Int64}
8+ index_missing:: IM
9+ fft_timer:: Base.RefValue{Float64}
10+ mapping_timer:: Base.RefValue{Float64}
11+ rdft:: Bool
12+ end
13+
14+ function FFTOperator{VT}(nβ, DFTdim, DFTsize, index_missing, rdft) where VT
15+ T = eltype(VT)
16+ A_vec = VT(undef, nβ)
17+ A = reshape(A_vec, DFTsize)
18+ buffer_real = A
19+ if rdft == true
20+ op = plan_rfft(A)
21+ M1 = (DFTsize[1 ] ÷ 2 )
22+ if DFTdim == 1
23+ buffer_complex1 = Complex{T}. (A[1 : M1+ 1 ])
24+ elseif DFTdim == 2
25+ buffer_complex1 = Complex{T}. (A[1 : M1+ 1 ,:])
26+ else
27+ buffer_complex1 = Complex{T}. (A[1 : M1+ 1 ,:,:])
28+ end
29+ buffer_complex2 = buffer_complex1
30+ else
31+ op = plan_fft(A)
32+ buffer_complex1 = Complex{T}. (A)
33+ buffer_complex2 = copy(buffer_complex1)
34+ end
35+ fft_timer = Ref{Float64}(0.0 )
36+ mapping_timer = Ref{Float64}(0.0 )
37+
38+ return FFTOperator(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, index_missing, fft_timer, mapping_timer, rdft)
39+ end
40+
41+ function M_perpt_z(op_fft:: FFTOperator , z)
42+ return M_perpt_z(op_fft. buffer_real, op_fft. buffer_complex1, op_fft. buffer_complex2, op_fft. op,
43+ op_fft. DFTdim, op_fft. DFTsize, z, op_fft. fft_timer, op_fft. mapping_timer; rdft= op_fft. rdft)
44+ end
45+
46+ function M_perp_beta(op_fft:: FFTOperator , beta)
47+ return M_perp_beta(op_fft. buffer_real, op_fft. buffer_complex1, op_fft. buffer_complex2, op_fft. op,
48+ op_fft. DFTdim, op_fft. DFTsize, beta, op_fft. index_missing, op_fft. fft_timer,
49+ op_fft. mapping_timer; rdft= op_fft. rdft)
50+ end
51+
52+ function M_perpt_M_perp_vec(op_fft:: FFTOperator , vec)
53+ tmp = M_perp_beta(op_fft, vec)
54+ tmp = M_perpt_z(op_fft, tmp)
55+ return tmp
56+ end
57+
58+ function M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, z, fft_timer, mapping_timer; rdft:: Bool = false )
59+ N = prod(DFTsize)
360
461 t1 = time_ns()
562 if rdft
@@ -14,18 +71,18 @@ function M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size
1471
1572 t3 = time_ns()
1673 beta = vec(buffer_real)
17- DFT_to_beta!(beta, dim, _size , temp; rdft)
74+ DFT_to_beta!(beta, DFTdim, DFTsize , temp; rdft)
1875 t4 = time_ns()
1976 mapping_timer[] = mapping_timer[] + (t4 - t3) / 1e9
2077 return beta
2178end
2279
23- function M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size , beta, idx_missing , fft_timer, mapping_timer; rdft:: Bool = false )
24- N = prod(_size )
80+ function M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize , beta, index_missing , fft_timer, mapping_timer; rdft:: Bool = false )
81+ N = prod(DFTsize )
2582
2683 t3 = time_ns()
2784 v = buffer_complex2
28- beta_to_DFT!(v, dim, _size , beta; rdft)
85+ beta_to_DFT!(v, DFTdim, DFTsize , beta; rdft)
2986 t4 = time_ns()
3087 mapping_timer[] = mapping_timer[] + (t4 - t3) / 1e9
3188
@@ -40,111 +97,111 @@ function M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, dim, _si
4097 t2 = time_ns()
4198 fft_timer[] = fft_timer[] + (t2 - t1) / 1e9
4299
43- buffer_real[idx_missing ] .= 0
100+ buffer_real[index_missing ] .= 0
44101 return buffer_real
45102end
46103
47- function M_perpt_M_perp_vec(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size , vec, idx_missing , fft_timer, mapping_timer; rdft:: Bool = false )
48- temp = M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size , vec, idx_missing , fft_timer, mapping_timer; rdft)
49- temp = M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size, temp , fft_timer, mapping_timer; rdft)
50- return temp
104+ function M_perpt_M_perp_vec(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize , vec, index_missing , fft_timer, mapping_timer; rdft:: Bool = false )
105+ tmp = M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize , vec, index_missing , fft_timer, mapping_timer; rdft)
106+ tmp = M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, tmp , fft_timer, mapping_timer; rdft)
107+ return tmp
51108end
52109
53110# mapping between DFT and real vector beta
54111
55112# mapping DFT to beta
56- # @param dim The dimension of the problem (dim = 1, 2, 3)
57- # @param size The size of each dimension of the problem
58- # (we only consider the cases when the sizes are even for all the dimenstions )
113+ # @param DFTdim The DFTdimension of the problem (DFTdim = 1, 2, 3)
114+ # @param size The size of each DFTdimension of the problem
115+ # (we only consider the cases when the sizes are even for all the DFTdimenstions )
59116# (size is a tuple, e.g. size = (10, 20, 30))
60117# @param v DFT
61118
62119# @details This fucnction maps DFT to beta
63120
64- # @return A 1-dimensional real vector beta whose length is the product of size
121+ # @return A 1-DFTdimensional real vector beta whose length is the product of size
65122# @example
66- # >dim = 2;
123+ # >DFTdim = 2;
67124# >size1 = (6, 8)
68125# >x = randn(6, 8)
69126# >v = fft(x)/sqrt(prod(size1))
70- # >beta = DFT_to_beta(dim , size1, v)
127+ # >beta = DFT_to_beta(DFTdim , size1, v)
71128
72- function DFT_to_beta!(beta, dim :: Int , size, v; rdft:: Bool = false )
73- if (dim == 1 )
129+ function DFT_to_beta!(beta, DFTdim :: Int , size, v; rdft:: Bool = false )
130+ if (DFTdim == 1 )
74131 DFT_to_beta_1d!(beta, v, size; rdft)
75- elseif (dim == 2 )
132+ elseif (DFTdim == 2 )
76133 DFT_to_beta_2d!(beta, v, size; rdft)
77134 else
78135 DFT_to_beta_3d!(beta, v, size; rdft)
79136 end
80137 return beta
81138end
82139
83- function DFT_to_beta(dim :: Int , size, v:: Array{ComplexF64} ; rdft:: Bool = false )
140+ function DFT_to_beta(DFTdim :: Int , size, v:: Array{ComplexF64} ; rdft:: Bool = false )
84141 N = prod(size)
85142 beta = Vector{Float64}(undef, N)
86- DFT_to_beta!(beta, dim , size, v; rdft)
143+ DFT_to_beta!(beta, DFTdim , size, v; rdft)
87144 return beta
88145end
89146
90- function DFT_to_beta(dim :: Int , size, v:: CuArray{ComplexF64} ; rdft:: Bool = false )
147+ function DFT_to_beta(DFTdim :: Int , size, v:: CuArray{ComplexF64} ; rdft:: Bool = false )
91148 N = prod(size)
92149 beta = CuVector{Float64}(undef, N)
93- DFT_to_beta!(beta, dim , size, v; rdft)
150+ DFT_to_beta!(beta, DFTdim , size, v; rdft)
94151 return beta
95152end
96153
97- function DFT_to_beta(dim :: Int , size, v:: ROCArray{ComplexF64} ; rdft:: Bool = false )
154+ function DFT_to_beta(DFTdim :: Int , size, v:: ROCArray{ComplexF64} ; rdft:: Bool = false )
98155 N = prod(size)
99156 beta = ROCVector{Float64}(undef, N)
100- DFT_to_beta!(beta, dim , size, v; rdft)
157+ DFT_to_beta!(beta, DFTdim , size, v; rdft)
101158 return beta
102159end
103160
104161# mapping beta to DFT
105- # @param dim The dimension of the problem (dim = 1, 2, 3)
106- # @param size The size of each dimension of the problem
107- # (we only consider the cases when the sizes are even for all the dimenstions )
162+ # @param DFTdim The DFTdimension of the problem (DFTdim = 1, 2, 3)
163+ # @param size The size of each DFTdimension of the problem
164+ # (we only consider the cases when the sizes are even for all the DFTdimenstions )
108165# (size is a tuple, e.g. size = (10, 20, 30))
109- # @param beta A 1-dimensional real vector with length equal to the product of size
166+ # @param beta A 1-DFTdimensional real vector with length equal to the product of size
110167
111168# @details This fucnction maps beta to DFT
112169
113170# @return DFT DFT shares the same size as param sizes
114171
115172# @example
116- # >dim = 2;
173+ # >DFTdim = 2;
117174# >size1 = (6, 8)
118175# >x = randn(6, 8)
119176# >v = fft(x)/sqrt(prod(size1))
120- # >beta = DFT_to_beta(dim , size1, v)
121- # >w = beta_to_DFT(dim , size1, beta) (w should be equal to v)
177+ # >beta = DFT_to_beta(DFTdim , size1, v)
178+ # >w = beta_to_DFT(DFTdim , size1, beta) (w should be equal to v)
122179
123- function beta_to_DFT!(v, dim :: Int , size, beta; rdft:: Bool = false )
124- if (dim == 1 )
180+ function beta_to_DFT!(v, DFTdim :: Int , size, beta; rdft:: Bool = false )
181+ if (DFTdim == 1 )
125182 v = beta_to_DFT_1d!(v, beta, size; rdft)
126- elseif (dim == 2 )
183+ elseif (DFTdim == 2 )
127184 v = beta_to_DFT_2d!(v, beta, size; rdft)
128185 else
129186 v = beta_to_DFT_3d!(v, beta, size; rdft)
130187 end
131188 return v
132189end
133190
134- function beta_to_DFT(dim :: Int , size, beta:: StridedVector{Float64} ; rdft:: Bool = false )
191+ function beta_to_DFT(DFTdim :: Int , size, beta:: StridedVector{Float64} ; rdft:: Bool = false )
135192 v = Array{ComplexF64}(undef, size)
136- beta_to_DFT!(v, dim , size, beta; rdft)
193+ beta_to_DFT!(v, DFTdim , size, beta; rdft)
137194 return v
138195end
139196
140- function beta_to_DFT(dim :: Int , size, beta:: StridedCuVector{Float64} ; rdft:: Bool = false )
197+ function beta_to_DFT(DFTdim :: Int , size, beta:: StridedCuVector{Float64} ; rdft:: Bool = false )
141198 v = CuArray{ComplexF64}(undef, size)
142- beta_to_DFT!(v, dim , size, beta; rdft)
199+ beta_to_DFT!(v, DFTdim , size, beta; rdft)
143200 return v
144201end
145202
146- function beta_to_DFT(dim :: Int , size, beta:: AMDGPU.StridedROCVector{Float64} ; rdft:: Bool = false )
203+ function beta_to_DFT(DFTdim :: Int , size, beta:: AMDGPU.StridedROCVector{Float64} ; rdft:: Bool = false )
147204 v = ROCArray{ComplexF64}(undef, size)
148- beta_to_DFT!(v, dim , size, beta; rdft)
205+ beta_to_DFT!(v, DFTdim , size, beta; rdft)
149206 return v
150207end
0 commit comments