Skip to content

Commit f398816

Browse files
authored
Add a FFTOperator (#26)
1 parent 5009a8b commit f398816

File tree

6 files changed

+149
-109
lines changed

6 files changed

+149
-109
lines changed

.github/workflows/action.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
runs-on: cuda
3535
strategy:
3636
matrix:
37-
julia-version: ['lts', '1']
37+
julia-version: ['1']
3838
julia-arch: [x64]
3939

4040
steps:
@@ -54,7 +54,7 @@ jobs:
5454
runs-on: amdgpu
5555
strategy:
5656
matrix:
57-
julia-version: ['lts', '1']
57+
julia-version: ['1']
5858
julia-arch: [x64]
5959

6060
steps:

src/CompressedSensingIPM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using FFTW
77
using Krylov
88
using NLPModels
99

10-
export FFTNLPModel, FFTKKTSystem, FFTParameters
10+
export FFTNLPModel, FFTKKTSystem, FFTParameters, FFTOperator
1111

1212
include("fft_utils.jl")
1313
include("fft_model.jl")

src/fft_model.jl

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,13 @@ mutable struct FFTParameters{AT,N,IM}
1313
end
1414
end
1515

16-
mutable struct FFTNLPModel{T,VT,FFT,R,C,P} <: AbstractNLPModel{T,VT}
16+
mutable struct FFTNLPModel{T,VT,FFT,P} <: AbstractNLPModel{T,VT}
1717
meta::NLPModelMeta{T,VT}
18+
counters::Counters
1819
parameters::P
1920
nβ::Int
20-
counters::Counters
21-
op::FFT
22-
buffer_real::R
23-
buffer_complex1::C
24-
buffer_complex2::C
21+
op_fft::FFT
2522
M_perpt_z0::VT
26-
rdft::Bool
27-
fft_timer::Base.RefValue{Float64}
28-
mapping_timer::Base.RefValue{Float64}
2923
krylov_solver::Symbol
3024
preconditioner::Bool
3125
end
@@ -37,6 +31,7 @@ function FFTNLPModel{VT}(parameters::FFTParameters;
3731
T = eltype(VT)
3832
DFTdim = parameters.DFTdim # problem size (1, 2, 3)
3933
DFTsize = parameters.DFTsize # problem dimension
34+
index_missing = parameters.index_missing
4035
nβ = prod(DFTsize)
4136
nvar = 2 * nβ
4237
ncon = 2 * nβ
@@ -69,31 +64,10 @@ function FFTNLPModel{VT}(parameters::FFTParameters;
6964
)
7065

7166
# FFT operator
72-
A_vec = VT(undef, nβ)
73-
A = reshape(A_vec, DFTsize)
74-
buffer_real = A
75-
if rdft == true
76-
op = plan_rfft(A)
77-
M1 = (DFTsize[1] ÷ 2)
78-
if DFTdim == 1
79-
buffer_complex1 = Complex{T}.(A[1:M1+1])
80-
elseif DFTdim == 2
81-
buffer_complex1 = Complex{T}.(A[1:M1+1,:])
82-
else
83-
buffer_complex1 = Complex{T}.(A[1:M1+1,:,:])
84-
end
85-
buffer_complex2 = buffer_complex1
86-
else
87-
op = plan_fft(A)
88-
buffer_complex1 = Complex{T}.(A)
89-
buffer_complex2 = copy(buffer_complex1)
90-
end
91-
fft_timer = Ref{Float64}(0.0)
92-
mapping_timer = Ref{Float64}(0.0)
93-
tmp = M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, parameters.z0, fft_timer, mapping_timer; rdft=rdft)
67+
op_fft = FFTOperator{VT}(nβ, DFTdim, DFTsize, index_missing, rdft)
68+
tmp = M_perpt_z(op_fft, parameters.z0)
9469
M_perpt_z0 = copy(tmp)
95-
return FFTNLPModel(meta, parameters, nβ, Counters(), op, buffer_real, buffer_complex1,
96-
buffer_complex2, M_perpt_z0, rdft, fft_timer, mapping_timer, krylov_solver, preconditioner)
70+
return FFTNLPModel(meta, Counters(), parameters, nβ, op_fft, M_perpt_z0, krylov_solver, preconditioner)
9771
end
9872

9973
function NLPModels.obj(nlp::FFTNLPModel, x::AbstractVector)
@@ -103,7 +77,7 @@ function NLPModels.obj(nlp::FFTNLPModel, x::AbstractVector)
10377
lambda = nlp.parameters.lambda
10478
index_missing = nlp.parameters.index_missing
10579

106-
fft_val = M_perp_beta(nlp.buffer_real, nlp.buffer_complex1, nlp.buffer_complex2, nlp.op, DFTdim, DFTsize, x, index_missing, nlp.fft_timer, nlp.mapping_timer; rdft=nlp.rdft)
80+
fft_val = M_perp_beta(nlp.op_fft, x)
10781
nβ = nlp.nβ
10882
beta = view(x, 1:nβ)
10983
c = view(x, nβ+1:2*nβ)
@@ -122,7 +96,7 @@ function NLPModels.grad!(nlp::FFTNLPModel, x::AbstractVector, g::AbstractVector)
12296
g_b = view(g, 1:nβ)
12397
g_c = view(g, nβ+1:2*nβ)
12498
beta = view(x, 1:nβ)
125-
res = M_perpt_M_perp_vec(nlp.buffer_real, nlp.buffer_complex1, nlp.buffer_complex2, nlp.op, DFTdim, DFTsize, beta, index_missing, nlp.fft_timer, nlp.mapping_timer; rdft=nlp.rdft)
99+
res = M_perpt_M_perp_vec(nlp.op_fft, beta)
126100
g_b .= res .- nlp.M_perpt_z0
127101
fill!(g_c, lambda)
128102
return g

src/fft_utils.jl

Lines changed: 99 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,62 @@
1-
function M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size, z, fft_timer, mapping_timer; rdft::Bool=false)
2-
N = prod(_size)
1+
struct FFTOperator{R,C,OP,N,IM}
2+
buffer_real::R
3+
buffer_complex1::C
4+
buffer_complex2::C
5+
op::OP
6+
DFTdim::Int64
7+
DFTsize::NTuple{N,Int64}
8+
index_missing::IM
9+
fft_timer::Base.RefValue{Float64}
10+
mapping_timer::Base.RefValue{Float64}
11+
rdft::Bool
12+
end
13+
14+
function FFTOperator{VT}(nβ, DFTdim, DFTsize, index_missing, rdft) where VT
15+
T = eltype(VT)
16+
A_vec = VT(undef, nβ)
17+
A = reshape(A_vec, DFTsize)
18+
buffer_real = A
19+
if rdft == true
20+
op = plan_rfft(A)
21+
M1 = (DFTsize[1] ÷ 2)
22+
if DFTdim == 1
23+
buffer_complex1 = Complex{T}.(A[1:M1+1])
24+
elseif DFTdim == 2
25+
buffer_complex1 = Complex{T}.(A[1:M1+1,:])
26+
else
27+
buffer_complex1 = Complex{T}.(A[1:M1+1,:,:])
28+
end
29+
buffer_complex2 = buffer_complex1
30+
else
31+
op = plan_fft(A)
32+
buffer_complex1 = Complex{T}.(A)
33+
buffer_complex2 = copy(buffer_complex1)
34+
end
35+
fft_timer = Ref{Float64}(0.0)
36+
mapping_timer = Ref{Float64}(0.0)
37+
38+
return FFTOperator(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, index_missing, fft_timer, mapping_timer, rdft)
39+
end
40+
41+
function M_perpt_z(op_fft::FFTOperator, z)
42+
return M_perpt_z(op_fft.buffer_real, op_fft.buffer_complex1, op_fft.buffer_complex2, op_fft.op,
43+
op_fft.DFTdim, op_fft.DFTsize, z, op_fft.fft_timer, op_fft.mapping_timer; rdft=op_fft.rdft)
44+
end
45+
46+
function M_perp_beta(op_fft::FFTOperator, beta)
47+
return M_perp_beta(op_fft.buffer_real, op_fft.buffer_complex1, op_fft.buffer_complex2, op_fft.op,
48+
op_fft.DFTdim, op_fft.DFTsize, beta, op_fft.index_missing, op_fft.fft_timer,
49+
op_fft.mapping_timer; rdft=op_fft.rdft)
50+
end
51+
52+
function M_perpt_M_perp_vec(op_fft::FFTOperator, vec)
53+
tmp = M_perp_beta(op_fft, vec)
54+
tmp = M_perpt_z(op_fft, tmp)
55+
return tmp
56+
end
57+
58+
function M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, z, fft_timer, mapping_timer; rdft::Bool=false)
59+
N = prod(DFTsize)
360

461
t1 = time_ns()
562
if rdft
@@ -14,18 +71,18 @@ function M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size
1471

1572
t3 = time_ns()
1673
beta = vec(buffer_real)
17-
DFT_to_beta!(beta, dim, _size, temp; rdft)
74+
DFT_to_beta!(beta, DFTdim, DFTsize, temp; rdft)
1875
t4 = time_ns()
1976
mapping_timer[] = mapping_timer[] + (t4 - t3) / 1e9
2077
return beta
2178
end
2279

23-
function M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size, beta, idx_missing, fft_timer, mapping_timer; rdft::Bool=false)
24-
N = prod(_size)
80+
function M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, beta, index_missing, fft_timer, mapping_timer; rdft::Bool=false)
81+
N = prod(DFTsize)
2582

2683
t3 = time_ns()
2784
v = buffer_complex2
28-
beta_to_DFT!(v, dim, _size, beta; rdft)
85+
beta_to_DFT!(v, DFTdim, DFTsize, beta; rdft)
2986
t4 = time_ns()
3087
mapping_timer[] = mapping_timer[] + (t4 - t3) / 1e9
3188

@@ -40,111 +97,111 @@ function M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, dim, _si
4097
t2 = time_ns()
4198
fft_timer[] = fft_timer[] + (t2 - t1) / 1e9
4299

43-
buffer_real[idx_missing] .= 0
100+
buffer_real[index_missing] .= 0
44101
return buffer_real
45102
end
46103

47-
function M_perpt_M_perp_vec(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size, vec, idx_missing, fft_timer, mapping_timer; rdft::Bool=false)
48-
temp = M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size, vec, idx_missing, fft_timer, mapping_timer; rdft)
49-
temp = M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, dim, _size, temp, fft_timer, mapping_timer; rdft)
50-
return temp
104+
function M_perpt_M_perp_vec(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, vec, index_missing, fft_timer, mapping_timer; rdft::Bool=false)
105+
tmp = M_perp_beta(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, vec, index_missing, fft_timer, mapping_timer; rdft)
106+
tmp = M_perpt_z(buffer_real, buffer_complex1, buffer_complex2, op, DFTdim, DFTsize, tmp, fft_timer, mapping_timer; rdft)
107+
return tmp
51108
end
52109

53110
# mapping between DFT and real vector beta
54111

55112
# mapping DFT to beta
56-
# @param dim The dimension of the problem (dim = 1, 2, 3)
57-
# @param size The size of each dimension of the problem
58-
#(we only consider the cases when the sizes are even for all the dimenstions)
113+
# @param DFTdim The dimension of the problem (DFTdim = 1, 2, 3)
114+
# @param size The size of each dimension of the problem
115+
#(we only consider the cases when the sizes are even for all the dimensions)
59116
#(size is a tuple, e.g. size = (10, 20, 30))
60117
# @param v DFT
61118

62119
# @details This function maps DFT to beta
63120

64-
# @return A 1-dimensional real vector beta whose length is the product of size
121+
# @return A 1-dimensional real vector beta whose length is the product of size
65122
# @example
66-
# >dim = 2;
123+
# >DFTdim = 2;
67124
# >size1 = (6, 8)
68125
# >x = randn(6, 8)
69126
# >v = fft(x)/sqrt(prod(size1))
70-
# >beta = DFT_to_beta(dim, size1, v)
127+
# >beta = DFT_to_beta(DFTdim, size1, v)
71128

72-
function DFT_to_beta!(beta, dim::Int, size, v; rdft::Bool=false)
73-
if (dim == 1)
129+
function DFT_to_beta!(beta, DFTdim::Int, size, v; rdft::Bool=false)
130+
if (DFTdim == 1)
74131
DFT_to_beta_1d!(beta, v, size; rdft)
75-
elseif (dim == 2)
132+
elseif (DFTdim == 2)
76133
DFT_to_beta_2d!(beta, v, size; rdft)
77134
else
78135
DFT_to_beta_3d!(beta, v, size; rdft)
79136
end
80137
return beta
81138
end
82139

83-
function DFT_to_beta(dim::Int, size, v::Array{ComplexF64}; rdft::Bool=false)
140+
function DFT_to_beta(DFTdim::Int, size, v::Array{ComplexF64}; rdft::Bool=false)
84141
N = prod(size)
85142
beta = Vector{Float64}(undef, N)
86-
DFT_to_beta!(beta, dim, size, v; rdft)
143+
DFT_to_beta!(beta, DFTdim, size, v; rdft)
87144
return beta
88145
end
89146

90-
function DFT_to_beta(dim::Int, size, v::CuArray{ComplexF64}; rdft::Bool=false)
147+
function DFT_to_beta(DFTdim::Int, size, v::CuArray{ComplexF64}; rdft::Bool=false)
91148
N = prod(size)
92149
beta = CuVector{Float64}(undef, N)
93-
DFT_to_beta!(beta, dim, size, v; rdft)
150+
DFT_to_beta!(beta, DFTdim, size, v; rdft)
94151
return beta
95152
end
96153

97-
function DFT_to_beta(dim::Int, size, v::ROCArray{ComplexF64}; rdft::Bool=false)
154+
function DFT_to_beta(DFTdim::Int, size, v::ROCArray{ComplexF64}; rdft::Bool=false)
98155
N = prod(size)
99156
beta = ROCVector{Float64}(undef, N)
100-
DFT_to_beta!(beta, dim, size, v; rdft)
157+
DFT_to_beta!(beta, DFTdim, size, v; rdft)
101158
return beta
102159
end
103160

104161
# mapping beta to DFT
105-
# @param dim The dimension of the problem (dim = 1, 2, 3)
106-
# @param size The size of each dimension of the problem
107-
#(we only consider the cases when the sizes are even for all the dimenstions)
162+
# @param DFTdim The dimension of the problem (DFTdim = 1, 2, 3)
163+
# @param size The size of each dimension of the problem
164+
#(we only consider the cases when the sizes are even for all the dimensions)
108165
#(size is a tuple, e.g. size = (10, 20, 30))
109-
# @param beta A 1-dimensional real vector with length equal to the product of size
166+
# @param beta A 1-dimensional real vector with length equal to the product of size
110167

111168
# @details This function maps beta to DFT
112169

113170
# @return DFT DFT shares the same size as param sizes
114171

115172
# @example
116-
# >dim = 2;
173+
# >DFTdim = 2;
117174
# >size1 = (6, 8)
118175
# >x = randn(6, 8)
119176
# >v = fft(x)/sqrt(prod(size1))
120-
# >beta = DFT_to_beta(dim, size1, v)
121-
# >w = beta_to_DFT(dim, size1, beta) (w should be equal to v)
177+
# >beta = DFT_to_beta(DFTdim, size1, v)
178+
# >w = beta_to_DFT(DFTdim, size1, beta) (w should be equal to v)
122179

123-
function beta_to_DFT!(v, dim::Int, size, beta; rdft::Bool=false)
124-
if (dim == 1)
180+
function beta_to_DFT!(v, DFTdim::Int, size, beta; rdft::Bool=false)
181+
if (DFTdim == 1)
125182
v = beta_to_DFT_1d!(v, beta, size; rdft)
126-
elseif (dim == 2)
183+
elseif (DFTdim == 2)
127184
v = beta_to_DFT_2d!(v, beta, size; rdft)
128185
else
129186
v = beta_to_DFT_3d!(v, beta, size; rdft)
130187
end
131188
return v
132189
end
133190

134-
function beta_to_DFT(dim::Int, size, beta::StridedVector{Float64}; rdft::Bool=false)
191+
function beta_to_DFT(DFTdim::Int, size, beta::StridedVector{Float64}; rdft::Bool=false)
135192
v = Array{ComplexF64}(undef, size)
136-
beta_to_DFT!(v, dim, size, beta; rdft)
193+
beta_to_DFT!(v, DFTdim, size, beta; rdft)
137194
return v
138195
end
139196

140-
function beta_to_DFT(dim::Int, size, beta::StridedCuVector{Float64}; rdft::Bool=false)
197+
function beta_to_DFT(DFTdim::Int, size, beta::StridedCuVector{Float64}; rdft::Bool=false)
141198
v = CuArray{ComplexF64}(undef, size)
142-
beta_to_DFT!(v, dim, size, beta; rdft)
199+
beta_to_DFT!(v, DFTdim, size, beta; rdft)
143200
return v
144201
end
145202

146-
function beta_to_DFT(dim::Int, size, beta::AMDGPU.StridedROCVector{Float64}; rdft::Bool=false)
203+
function beta_to_DFT(DFTdim::Int, size, beta::AMDGPU.StridedROCVector{Float64}; rdft::Bool=false)
147204
v = ROCArray{ComplexF64}(undef, size)
148-
beta_to_DFT!(v, dim, size, beta; rdft)
205+
beta_to_DFT!(v, DFTdim, size, beta; rdft)
149206
return v
150207
end

src/kkt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function LinearAlgebra.mul!(y::AbstractVector, K::CondensedFFTKKTSystem, x::Abst
4343
xz = view(x, nβ+1:2*nβ)
4444

4545
# Evaluate Mᵀ M xβ
46-
Mβ .= M_perpt_M_perp_vec(nlp.buffer_real, nlp.buffer_complex1, nlp.buffer_complex2, nlp.op, DFTdim, DFTsize, xβ, index_missing, nlp.fft_timer, nlp.mapping_timer; K.nlp.rdft)
46+
Mβ .= M_perpt_M_perp_vec(nlp.op_fft, xβ)
4747

4848
yβ .= beta .* yβ .+ alpha .* (Mβ .+ K.Λ1 .* xβ .+ K.Λ2 .* xz)
4949
yz .= beta .* yz .+ alpha .* (K.Λ2 .* xβ .+ K.Λ1 .* xz)
@@ -252,7 +252,7 @@ function MadNLP.mul!(y::VT, kkt::FFTKKTSystem, x::VT, alpha::Number, beta::Numbe
252252
xy2 = view(_x, 5*nβ+1:6*nβ)
253253

254254
# Evaluate (MᵀM) * xβ
255-
Mβ .= M_perpt_M_perp_vec(kkt.nlp.buffer_real, kkt.nlp.buffer_complex1, kkt.nlp.buffer_complex2, kkt.nlp.op, DFTdim, DFTsize, xβ, index_missing, kkt.nlp.fft_timer, kkt.nlp.mapping_timer; kkt.nlp.rdft)
255+
Mβ .= M_perpt_M_perp_vec(kkt.nlp.op_fft, xβ)
256256
yβ .= beta .* yβ .+ alpha .* (Mβ .- xy1 .+ xy2)
257257
yz .= beta .* yz .- alpha .* (xy1 .+ xy2)
258258
ys1 .= beta .* ys1 .- alpha .* xy1

0 commit comments

Comments
 (0)