Skip to content

Commit c745569

Browse files
authored
Merge pull request #629 from JuliaParallel/jps/fft-only
DArray: Add FFT implementation for 1D/2D/3D
2 parents 9c7be3f + 1b57eb7 commit c745569

File tree

6 files changed

+470
-0
lines changed

6 files changed

+470
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
3131

3232
[weakdeps]
3333
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
34+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3435
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3536
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
3637
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@@ -44,6 +45,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
4445
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
4546

4647
[extensions]
48+
AbstractFFTsExt = "AbstractFFTs"
4749
CUDAExt = "CUDA"
4850
DistributionsExt = "Distributions"
4951
GraphVizExt = "GraphViz"
@@ -58,6 +60,7 @@ ROCExt = "AMDGPU"
5860

5961
[compat]
6062
AMDGPU = "1"
63+
AbstractFFTs = "1.5.0"
6164
Adapt = "4"
6265
CUDA = "3, 4, 5"
6366
Colors = "0.12, 0.13"

docs/src/darray.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,3 +447,7 @@ From `LinearAlgebra`:
447447
- `mul!` (In-place Matrix-Matrix multiply)
448448
- `cholesky`/`cholesky!` (Out-of-place/In-place Cholesky factorization)
449449
- `lu`/`lu!` (Out-of-place/In-place LU factorization (`NoPivot` only))
450+
451+
From `AbstractFFTs`:
452+
- `fft`/`fft!` (Out-of-place/In-place forward FFT of 1D, 2D, and 3D arrays)
453+
- `ifft`/`ifft!` (Out-of-place/In-place inverse FFT of 1D, 2D, and 3D arrays)

ext/AbstractFFTsExt.jl

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
module AbstractFFTsExt
2+
3+
import Dagger
4+
import Dagger: DArray, DVector, DMatrix, Blocks, AutoBlocks, InOut
5+
import AbstractFFTs
6+
import LinearAlgebra
7+
8+
"""
    Decomposition

Supertype of the strategies that select how a 3D array is re-blocked between the
stages of a distributed FFT. Concrete subtypes are [`Pencil`](@ref) and
[`Slab`](@ref).
"""
abstract type Decomposition end

"""
    Pencil <: Decomposition

Selects the three-stage 3D algorithm: the mid-level `_fft!`/`_ifft!` build three
temporary layouts and the transform is applied along one dimension per stage.
"""
struct Pencil <: Decomposition end

"""
    Slab <: Decomposition

Selects the two-stage 3D algorithm: the mid-level `_fft!`/`_ifft!` build two
temporary layouts, one transformed along two dimensions at once and the other
along the remaining dimension.
"""
struct Slab <: Decomposition end
11+
12+
# High-level interface
13+
14+
## TODO: Add optimized 1D algorithm
15+
16+
## 1D out-of-place
17+
"""
    AbstractFFTs.fft(A::DVector) -> DVector

Out-of-place forward FFT of a distributed vector.

The vector is gathered with `collect`, transformed with `AbstractFFTs.fft`, and the
result is wrapped in a new `DVector`. This is the unoptimized 1D path (see the TODO
above about an optimized 1D algorithm).
"""
function AbstractFFTs.fft(A::DVector)
    gathered = collect(A)
    return DVector(AbstractFFTs.fft(gathered))
end
18+
"""
    AbstractFFTs.ifft(A::DVector) -> DVector

Out-of-place inverse FFT of a distributed vector.

The vector is gathered with `collect`, transformed with `AbstractFFTs.ifft`, and the
result is wrapped in a new `DVector`.
"""
function AbstractFFTs.ifft(A::DVector)
    gathered = collect(A)
    return DVector(AbstractFFTs.ifft(gathered))
end
19+
20+
## 1D in-place
21+
"""
    AbstractFFTs.fft!(DA::DVector{T}) -> DA

In-place forward FFT of a distributed vector.

The contents of `DA` are copied into a local `Vector{T}`, transformed there with
`AbstractFFTs.fft!`, and copied back into `DA`, which is returned.
"""
function AbstractFFTs.fft!(DA::DVector{T}) where T
    # Stage the distributed data in a plain local vector of the same element type.
    buffer = Vector{T}(undef, length(DA))
    copyto!(buffer, DA)

    AbstractFFTs.fft!(buffer)

    # Write the transformed values back into the distributed vector.
    copyto!(DA, buffer)
    return DA
end
28+
"""
    AbstractFFTs.ifft!(DA::DVector{T}) -> DA

In-place inverse FFT of a distributed vector.

The contents of `DA` are copied into a local `Vector{T}`, transformed there with
`AbstractFFTs.ifft!`, and copied back into `DA`, which is returned.
"""
function AbstractFFTs.ifft!(DA::DVector{T}) where T
    # Stage the distributed data in a plain local vector of the same element type.
    buffer = Vector{T}(undef, length(DA))
    copyto!(buffer, DA)

    AbstractFFTs.ifft!(buffer)

    # Write the transformed values back into the distributed vector.
    copyto!(DA, buffer)
    return DA
end
35+
36+
## 2D out-of-place
37+
"""
    AbstractFFTs.fft(DA::DMatrix, dims=(1, 2)) -> DMatrix

Out-of-place forward FFT of a distributed matrix.

Allocates the result with `similar(DA)` and fills it through the mid-level `_fft!`;
`DA` is only read.

NOTE(review): the result shares `DA`'s element type because it comes from
`similar(DA)`; real-valued inputs are presumably unsupported — confirm.
"""
function AbstractFFTs.fft(DA::DMatrix, dims=(1, 2))
    result = similar(DA)
    _fft!(result, DA, dims)
    return result
end
42+
"""
    AbstractFFTs.ifft(DA::DMatrix, dims=(1, 2)) -> DMatrix

Out-of-place inverse FFT of a distributed matrix.

Allocates the result with `similar(DA)` and fills it through the mid-level `_ifft!`;
`DA` is only read.
"""
function AbstractFFTs.ifft(DA::DMatrix, dims=(1, 2))
    result = similar(DA)
    _ifft!(result, DA, dims)
    return result
end
47+
48+
## 2D in-place
49+
"""
    AbstractFFTs.fft!(DA::DMatrix{T}, dims=(1, 2)) -> DA

In-place forward FFT of a distributed matrix.

Calls the mid-level `_fft!` with `DA` as both input and output, then returns `DA`.
"""
function AbstractFFTs.fft!(DA::DMatrix{T}, dims=(1, 2)) where T
    # Same array on both sides: `_fft!` reads it first and writes the result last.
    source = DA
    _fft!(DA, source, dims)
    return DA
end
53+
"""
    AbstractFFTs.ifft!(DA::DMatrix{T}, dims=(1, 2)) -> DA

In-place inverse FFT of a distributed matrix.

Calls the mid-level `_ifft!` with `DA` as both input and output, then returns `DA`.
"""
function AbstractFFTs.ifft!(DA::DMatrix{T}, dims=(1, 2)) where T
    # Same array on both sides: `_ifft!` reads it first and writes the result last.
    source = DA
    _ifft!(DA, source, dims)
    return DA
end
57+
58+
## 3D out-of-place
59+
"""
    AbstractFFTs.fft(DA::DArray{T,3}, dims=(1, 2, 3); decomp=Pencil()) -> DArray

Out-of-place forward FFT of a 3D distributed array.

# Keywords
- `decomp`: a `Decomposition` instance or one of the symbols accepted by
  `_to_decomp` (`:pencil`, `:slab`).

The result is allocated with `similar(DA)` and filled by the mid-level `_fft!`.
"""
function AbstractFFTs.fft(DA::DArray{T,3}, dims=(1, 2, 3); decomp::Union{Decomposition,Symbol}=Pencil()) where T
    result = similar(DA)
    # Normalize a symbol such as `:slab` into its `Decomposition` instance.
    strategy = _to_decomp(decomp)
    _fft!(result, DA, dims; decomp=strategy)
    return result
end
65+
"""
    AbstractFFTs.ifft(DA::DArray{T,3}, dims=(1, 2, 3); decomp=Pencil()) -> DArray

Out-of-place inverse FFT of a 3D distributed array.

# Keywords
- `decomp`: a `Decomposition` instance or one of the symbols accepted by
  `_to_decomp` (`:pencil`, `:slab`).

The result is allocated with `similar(DA)` and filled by the mid-level `_ifft!`.
"""
function AbstractFFTs.ifft(DA::DArray{T,3}, dims=(1, 2, 3); decomp::Union{Decomposition,Symbol}=Pencil()) where T
    result = similar(DA)
    # Normalize a symbol such as `:slab` into its `Decomposition` instance.
    strategy = _to_decomp(decomp)
    _ifft!(result, DA, dims; decomp=strategy)
    return result
end
71+
72+
## 3D in-place
73+
"""
    AbstractFFTs.fft!(DA::DArray{T,3}, dims=(1, 2, 3); decomp=Pencil()) -> DA

In-place forward FFT of a 3D distributed array.

# Keywords
- `decomp`: a `Decomposition` instance or one of the symbols accepted by
  `_to_decomp` (`:pencil`, `:slab`).

Calls the mid-level `_fft!` with `DA` as both input and output, then returns `DA`.
"""
function AbstractFFTs.fft!(DA::DArray{T,3}, dims=(1, 2, 3); decomp::Union{Decomposition,Symbol}=Pencil()) where T
    # Normalize a symbol such as `:slab` into its `Decomposition` instance.
    strategy = _to_decomp(decomp)
    _fft!(DA, DA, dims; decomp=strategy)
    return DA
end
78+
"""
    AbstractFFTs.ifft!(DA::DArray{T,3}, dims=(1, 2, 3); decomp=Pencil()) -> DA

In-place inverse FFT of a 3D distributed array.

# Keywords
- `decomp`: a `Decomposition` instance or one of the symbols accepted by
  `_to_decomp` (`:pencil`, `:slab`).

Calls the mid-level `_ifft!` with `DA` as both input and output, then returns `DA`.
"""
function AbstractFFTs.ifft!(DA::DArray{T,3}, dims=(1, 2, 3); decomp::Union{Decomposition,Symbol}=Pencil()) where T
    # Normalize a symbol such as `:slab` into its `Decomposition` instance.
    strategy = _to_decomp(decomp)
    _ifft!(DA, DA, dims; decomp=strategy)
    return DA
end
83+
84+
# Mid-level interface
85+
86+
"""
    _to_decomp(decomp::Decomposition) -> Decomposition

Identity method: a value that is already a `Decomposition` is returned unchanged.
"""
function _to_decomp(decomp::Decomposition)
    return decomp
end
87+
"""
    _to_decomp(decomp::Symbol) -> Decomposition

Translate a symbolic decomposition name into its singleton instance:
`:pencil` gives `Pencil()` and `:slab` gives `Slab()`.

# Throws
- `ArgumentError` for any other symbol.
"""
function _to_decomp(decomp::Symbol)
    decomp == :pencil && return Pencil()
    decomp == :slab && return Slab()
    throw(ArgumentError("Unknown decomposition type: $decomp\nSupported types: :pencil, :slab"))
end
96+
97+
## 2D
98+
"""
    _fft!(output::DMatrix{T}, input::DMatrix{T}, dims=(1, 2)) -> output

Compute the forward FFT of `input` into `output` through two temporary layouts.

`input` is copied into `A`, whose blocks span every row. `__fft!` transforms `A`
along `dims[1]`, redistributes it into `B` (whose blocks span every column), and
transforms `B` along `dims[2]`. `B` is then copied into `output`, which may be the
same array as `input`.

Each block extent is derived from the size of its own dimension, so non-square
inputs get blocks that cover the whole transformed dimension. The split extent is
clamped to at least 1 so that having more processors than rows or columns does not
produce a zero-sized block.

NOTE(review): the layouts assume `dims == (1, 2)`; other values are forwarded to
`__fft!` unchanged and are probably not handled correctly — confirm.
"""
function _fft!(output::DMatrix{T}, input::DMatrix{T}, dims=(1, 2)) where T
    nrows, ncols = size(input)
    np = length(Dagger.compatible_processors())

    # Extent of the dimension that is split across processors in each stage.
    row_split = max(1, div(nrows, np))
    col_split = max(1, div(ncols, np))

    # Stage 1 layout: full columns, so the transform along dimension 1 is chunk-local.
    A = zeros(Blocks(nrows, col_split), T, size(input))
    copyto!(A, input)

    # Stage 2 layout: full rows, so the transform along dimension 2 is chunk-local.
    B = zeros(Blocks(row_split, ncols), T, size(input))

    __fft!(A, B, dims)
    copyto!(output, B)
    return output
end
108+
"""
    _ifft!(output::DMatrix{T}, input::DMatrix{T}, dims=(1, 2)) -> output

Compute the inverse FFT of `input` into `output` through two temporary layouts.

`input` is copied into `A`, whose blocks span every row. `__ifft!` transforms `A`
along `dims[1]`, redistributes it into `B` (whose blocks span every column), and
transforms `B` along `dims[2]`. `B` is then copied into `output`, which may be the
same array as `input`.

Each block extent is derived from the size of its own dimension, so non-square
inputs get blocks that cover the whole transformed dimension. The split extent is
clamped to at least 1 so that having more processors than rows or columns does not
produce a zero-sized block.

NOTE(review): the layouts assume `dims == (1, 2)`; other values are forwarded to
`__ifft!` unchanged and are probably not handled correctly — confirm.
"""
function _ifft!(output::DMatrix{T}, input::DMatrix{T}, dims=(1, 2)) where T
    nrows, ncols = size(input)
    np = length(Dagger.compatible_processors())

    # Extent of the dimension that is split across processors in each stage.
    row_split = max(1, div(nrows, np))
    col_split = max(1, div(ncols, np))

    # Stage 1 layout: full columns, so the transform along dimension 1 is chunk-local.
    A = zeros(Blocks(nrows, col_split), T, size(input))
    copyto!(A, input)

    # Stage 2 layout: full rows, so the transform along dimension 2 is chunk-local.
    B = zeros(Blocks(row_split, ncols), T, size(input))

    __ifft!(A, B, dims)
    copyto!(output, B)
    return output
end
118+
119+
## 3D
120+
"""
    _fft!(output::DArray{T,3}, input::DArray{T,3}, dims=(1, 2, 3); decomp=Pencil()) -> output

Compute the forward FFT of a 3D distributed array into `output` (which may be the
same array as `input`) using temporary re-blocked copies.

- `Pencil`: three layouts `A`, `B`, `C` whose blocks span dimension 1, 2, and 3
  respectively; `__fft!` transforms one dimension per layout.
- `Slab`: two layouts; `A`'s blocks span dimensions 1 and 2 and `B`'s blocks span
  dimension 3.

Each block extent is derived from the size of its own dimension, so non-cubic
inputs get blocks that cover the whole transformed dimension. Split extents are
clamped to at least 1 so that having more processors than elements along a
dimension does not produce a zero-sized block.

# Throws
- `ArgumentError` if `decomp` is neither a `Pencil` nor a `Slab`.

NOTE(review): the layouts assume `dims == (1, 2, 3)`; other values are forwarded
unchanged and are probably not handled correctly — confirm.
"""
function _fft!(output::DArray{T,3}, input::DArray{T,3}, dims=(1, 2, 3); decomp::Decomposition=Pencil()) where T
    n1, n2, n3 = size(input)
    np = length(Dagger.compatible_processors())

    # Extent used along a dimension when it is split across processors.
    s1 = max(1, div(n1, np))
    s2 = max(1, div(n2, np))
    s3 = max(1, div(n3, np))

    if decomp isa Pencil
        A = zeros(Blocks(n1, s2, s3), T, size(input))
        B = zeros(Blocks(s1, n2, s3), T, size(input))
        C = zeros(Blocks(s1, s2, n3), T, size(input))
        copyto!(A, input)
        __fft!(decomp, A, B, C, dims)
        copyto!(output, C)
        return output
    elseif decomp isa Slab
        A = zeros(Blocks(n1, n2, s3), T, size(input))
        B = zeros(Blocks(s1, s2, n3), T, size(input))
        copyto!(A, input)
        __fft!(decomp, A, B, dims)
        copyto!(output, B)
        return output
    else
        throw(ArgumentError("Unknown decomposition type: $decomp"))
    end
end
142+
"""
    _ifft!(output::DArray{T,3}, input::DArray{T,3}, dims=(1, 2, 3); decomp=Pencil()) -> output

Compute the inverse FFT of a 3D distributed array into `output` (which may be the
same array as `input`) using temporary re-blocked copies. The stages run in the
reverse order of the forward transform.

- `Pencil`: three layouts `A`, `B`, `C` whose blocks span dimension 3, 2, and 1
  respectively; `__ifft!` transforms one dimension per layout.
- `Slab`: two layouts; `A`'s blocks span dimension 3 and `B`'s blocks span
  dimensions 1 and 2.

Each block extent is derived from the size of its own dimension, so non-cubic
inputs get blocks that cover the whole transformed dimension. Split extents are
clamped to at least 1 so that having more processors than elements along a
dimension does not produce a zero-sized block.

# Throws
- `ArgumentError` if `decomp` is neither a `Pencil` nor a `Slab`, matching the
  forward `_fft!`; without this the call would return `nothing` and leave `output`
  untouched.

NOTE(review): the layouts assume `dims == (1, 2, 3)`; other values are forwarded
unchanged and are probably not handled correctly — confirm.
"""
function _ifft!(output::DArray{T,3}, input::DArray{T,3}, dims=(1, 2, 3); decomp::Decomposition=Pencil()) where T
    n1, n2, n3 = size(input)
    np = length(Dagger.compatible_processors())

    # Extent used along a dimension when it is split across processors.
    s1 = max(1, div(n1, np))
    s2 = max(1, div(n2, np))
    s3 = max(1, div(n3, np))

    if decomp isa Pencil
        A = zeros(Blocks(s1, s2, n3), T, size(input))
        B = zeros(Blocks(s1, n2, s3), T, size(input))
        C = zeros(Blocks(n1, s2, s3), T, size(input))
        copyto!(A, input)
        __ifft!(decomp, A, B, C, dims)
        copyto!(output, C)
        return output
    elseif decomp isa Slab
        A = zeros(Blocks(s1, s2, n3), T, size(input))
        B = zeros(Blocks(n1, n2, s3), T, size(input))
        copyto!(A, input)
        __ifft!(decomp, A, B, dims)
        copyto!(output, B)
        return output
    else
        throw(ArgumentError("Unknown decomposition type: $decomp"))
    end
end
162+
163+
# Internal functions
164+
165+
"""
    FFT!

Tag selecting the in-place forward transform; `plan_transform` maps it to
`AbstractFFTs.plan_fft!`.
"""
struct FFT! end

"""
    RFFT!

Tag for a real-input forward transform. `plan_transform` has no branch for it and
throws `ArgumentError` if it is passed.
"""
struct RFFT! end

"""
    IRFFT!

Tag for a real-output inverse transform. `plan_transform` has no branch for it and
throws `ArgumentError` if it is passed.
"""
struct IRFFT! end

"""
    IFFT!

Tag selecting the in-place inverse transform; `plan_transform` maps it to
`AbstractFFTs.plan_ifft!`.
"""
struct IFFT! end
169+
170+
"""
    plan_transform(transform, A, dims; kwargs...)

Build an in-place FFT plan for `A` along `dims`.

`FFT!` yields `AbstractFFTs.plan_fft!(A, dims; kwargs...)` and `IFFT!` yields
`AbstractFFTs.plan_ifft!(A, dims; kwargs...)`.

# Throws
- `ArgumentError` for any other `transform` value.
"""
function plan_transform(transform, A, dims; kwargs...)
    transform isa FFT! && return AbstractFFTs.plan_fft!(A, dims; kwargs...)
    transform isa IFFT! && return AbstractFFTs.plan_ifft!(A, dims; kwargs...)
    throw(ArgumentError("Unknown transform type: $transform"))
end
179+
"""
    apply_fft!(out_part, in_part, transform, dim) -> nothing

Plan `transform` for `in_part` along `dim` with `plan_transform`, then apply the
plan with `LinearAlgebra.mul!`, storing the result in `out_part`.
"""
function apply_fft!(out_part, in_part, transform, dim)
    LinearAlgebra.mul!(out_part, plan_transform(transform, in_part, dim), in_part)
    return nothing
end
184+
"""
    apply_fft!(inout_part, transform, dim)

In-place convenience method: forwards to the four-argument `apply_fft!` with
`inout_part` used as both the output and the input.
"""
function apply_fft!(inout_part, transform, dim)
    return apply_fft!(inout_part, inout_part, transform, dim)
end
185+
186+
## 2D
187+
"""
    __fft!(A::DMatrix{T}, B::DMatrix{T}, dims) -> nothing

Two-stage forward FFT over pre-blocked matrices. Every chunk of `A` is transformed
in place along `dims[1]`, `A` is copied into `B`, and every chunk of `B` is then
transformed in place along `dims[2]`. The result is left in `B`.
"""
function __fft!(A::DMatrix{T}, B::DMatrix{T}, dims) where T
    # Grab both chunk collections up front, before any stage runs.
    first_stage = A.chunks
    second_stage = B.chunks

    # Stage 1: one task per chunk of `A`, each mutating its chunk in place.
    Dagger.spawn_datadeps() do
        for chunk in first_stage
            Dagger.@spawn name="apply_fft!(dim 1)" apply_fft!(InOut(chunk), FFT!(), dims[1])
        end
    end

    # Redistribute the partially transformed data into `B`'s blocking.
    copyto!(B, A)

    # Stage 2: one task per chunk of `B`.
    Dagger.spawn_datadeps() do
        for chunk in second_stage
            Dagger.@spawn name="apply_fft!(dim 2)" apply_fft!(InOut(chunk), FFT!(), dims[2])
        end
    end

    return nothing
end
206+
"""
    __ifft!(A::DMatrix{T}, B::DMatrix{T}, dims) -> nothing

Two-stage inverse FFT over pre-blocked matrices. Every chunk of `A` is transformed
in place along `dims[1]`, `A` is copied into `B`, and every chunk of `B` is then
transformed in place along `dims[2]`. The result is left in `B`.
"""
function __ifft!(A::DMatrix{T}, B::DMatrix{T}, dims) where T
    # Grab both chunk collections up front, before any stage runs.
    first_stage = A.chunks
    second_stage = B.chunks

    # Stage 1: one task per chunk of `A`, each mutating its chunk in place.
    Dagger.spawn_datadeps() do
        for chunk in first_stage
            Dagger.@spawn name="apply_ifft!(dim 1)" apply_fft!(InOut(chunk), IFFT!(), dims[1])
        end
    end

    # Redistribute the partially transformed data into `B`'s blocking.
    copyto!(B, A)

    # Stage 2: one task per chunk of `B`.
    Dagger.spawn_datadeps() do
        for chunk in second_stage
            Dagger.@spawn name="apply_ifft!(dim 2)" apply_fft!(InOut(chunk), IFFT!(), dims[2])
        end
    end

    return nothing
end
225+
226+
## 3D
227+
"""
    __fft!(::Pencil, A, B, C, dims) -> nothing

Three-stage forward FFT over pre-blocked 3D arrays. Chunks of `A` are transformed
in place along `dims[1]`, `A` is copied into `B`, chunks of `B` are transformed
along `dims[2]`, `B` is copied into `C`, and chunks of `C` are transformed along
`dims[3]`. The result is left in `C`.
"""
function __fft!(::Pencil, A::DArray{T,3}, B::DArray{T,3}, C::DArray{T,3}, dims) where T
    # Grab all chunk collections up front, before any stage runs.
    first_stage = A.chunks
    second_stage = B.chunks
    third_stage = C.chunks

    # Stage 1: transform along dims[1].
    Dagger.spawn_datadeps() do
        for chunk in first_stage
            Dagger.@spawn name="apply_fft!(dim 1)" apply_fft!(InOut(chunk), FFT!(), dims[1])
        end
    end

    # Re-block into `B`, then transform along dims[2].
    copyto!(B, A)
    Dagger.spawn_datadeps() do
        for chunk in second_stage
            Dagger.@spawn name="apply_fft!(dim 2)" apply_fft!(InOut(chunk), FFT!(), dims[2])
        end
    end

    # Re-block into `C`, then transform along dims[3].
    copyto!(C, B)
    Dagger.spawn_datadeps() do
        for chunk in third_stage
            Dagger.@spawn name="apply_fft!(dim 3)" apply_fft!(InOut(chunk), FFT!(), dims[3])
        end
    end

    return nothing
end
254+
"""
    __fft!(::Slab, A, B, dims) -> nothing

Two-stage forward FFT over pre-blocked 3D arrays. Chunks of `A` are transformed in
place along `(dims[1], dims[2])` together, `A` is copied into `B`, and chunks of
`B` are transformed along `dims[3]`. The result is left in `B`.
"""
function __fft!(::Slab, A::DArray{T,3}, B::DArray{T,3}, dims) where T
    # Grab both chunk collections up front, before any stage runs.
    first_stage = A.chunks
    second_stage = B.chunks

    # Stage 1: combined transform over the first two requested dimensions.
    Dagger.spawn_datadeps() do
        for chunk in first_stage
            Dagger.@spawn name="apply_fft!(dim 1&2)" apply_fft!(InOut(chunk), FFT!(), (dims[1], dims[2]))
        end
    end

    # Re-block into `B`, then transform along dims[3].
    copyto!(B, A)
    Dagger.spawn_datadeps() do
        for chunk in second_stage
            Dagger.@spawn name="apply_fft!(dim 3)" apply_fft!(InOut(chunk), FFT!(), dims[3])
        end
    end

    return nothing
end
273+
"""
    __ifft!(::Pencil, A, B, C, dims) -> nothing

Three-stage inverse FFT over pre-blocked 3D arrays, visiting dimensions in the
reverse order of the forward transform. Chunks of `A` are transformed in place
along `dims[3]`, `A` is copied into `B`, chunks of `B` are transformed along
`dims[2]`, `B` is copied into `C`, and chunks of `C` are transformed along
`dims[1]`. The result is left in `C`.
"""
function __ifft!(::Pencil, A::DArray{T,3}, B::DArray{T,3}, C::DArray{T,3}, dims) where T
    # Grab all chunk collections up front, before any stage runs.
    first_stage = A.chunks
    second_stage = B.chunks
    third_stage = C.chunks

    # Stage 1: transform along dims[3].
    Dagger.spawn_datadeps() do
        for chunk in first_stage
            Dagger.@spawn name="apply_ifft!(dim 3)" apply_fft!(InOut(chunk), IFFT!(), dims[3])
        end
    end

    # Re-block into `B`, then transform along dims[2].
    copyto!(B, A)
    Dagger.spawn_datadeps() do
        for chunk in second_stage
            Dagger.@spawn name="apply_ifft!(dim 2)" apply_fft!(InOut(chunk), IFFT!(), dims[2])
        end
    end

    # Re-block into `C`, then transform along dims[1].
    copyto!(C, B)
    Dagger.spawn_datadeps() do
        for chunk in third_stage
            Dagger.@spawn name="apply_ifft!(dim 1)" apply_fft!(InOut(chunk), IFFT!(), dims[1])
        end
    end

    return nothing
end
300+
"""
    __ifft!(::Slab, A, B, dims) -> nothing

Two-stage inverse FFT over pre-blocked 3D arrays. Chunks of `A` are transformed in
place along `dims[3]`, `A` is copied into `B`, and chunks of `B` are transformed
along `(dims[1], dims[2])` together. The result is left in `B`.
"""
function __ifft!(::Slab, A::DArray{T,3}, B::DArray{T,3}, dims) where T
    # Grab both chunk collections up front, before any stage runs.
    first_stage = A.chunks
    second_stage = B.chunks

    # Stage 1: transform along dims[3].
    Dagger.spawn_datadeps() do
        for chunk in first_stage
            Dagger.@spawn name="apply_ifft!(dim 3)" apply_fft!(InOut(chunk), IFFT!(), dims[3])
        end
    end

    # Re-block into `B`, then run the combined transform over the first two dimensions.
    copyto!(B, A)
    Dagger.spawn_datadeps() do
        for chunk in second_stage
            Dagger.@spawn name="apply_ifft!(dim 1&2)" apply_fft!(InOut(chunk), IFFT!(), (dims[1], dims[2]))
        end
    end

    return nothing
end
319+
320+
end # module AbstractFFTsExt

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
55
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
66
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
89
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
910
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1011
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"

0 commit comments

Comments
 (0)