diff --git a/Project.toml b/Project.toml
index 217ab01a4..51815a84c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -31,6 +31,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 
 [weakdeps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@@ -44,6 +45,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
 
 [extensions]
+AbstractFFTsExt = "AbstractFFTs"
 CUDAExt = "CUDA"
 DistributionsExt = "Distributions"
 GraphVizExt = "GraphViz"
@@ -58,6 +60,7 @@ ROCExt = "AMDGPU"
 
 [compat]
 AMDGPU = "1"
+AbstractFFTs = "1.5.0"
 Adapt = "4"
 CUDA = "3, 4, 5"
 Colors = "0.12, 0.13"
diff --git a/docs/src/darray.md b/docs/src/darray.md
index 715a6cbe8..e813bfeda 100644
--- a/docs/src/darray.md
+++ b/docs/src/darray.md
@@ -447,3 +447,10 @@ From `LinearAlgebra`:
 - `mul!` (In-place Matrix-Matrix multiply)
 - `cholesky`/`cholesky!` (In-place/Out-of-place Cholesky factorization)
 - `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` only))
+
+From `AbstractFFTs`:
+- `fft`/`fft!`
+- `ifft`/`ifft!`
+
+FFTs are supported for 1D, 2D, and 3D `DArray`s and always transform every dimension.
+3D transforms accept `decomp=:pencil` (the default) or `decomp=:slab`.
diff --git a/ext/AbstractFFTsExt.jl b/ext/AbstractFFTsExt.jl
new file mode 100644
index 000000000..5f24ba584
--- /dev/null
+++ b/ext/AbstractFFTsExt.jl
@@ -0,0 +1,274 @@
+module AbstractFFTsExt
+
+import Dagger
+import Dagger: DArray, DVector, DMatrix, Blocks, AutoBlocks, InOut
+import AbstractFFTs
+import LinearAlgebra
+
+"""
+    Decomposition
+
+Strategy for re-partitioning a 3D `DArray` between the passes of a distributed FFT.
+"""
+abstract type Decomposition end
+
+"Transform one dimension per pass, re-partitioning the array between passes."
+struct Pencil <: Decomposition end
+
+"Transform dimensions 1 and 2 in a single pass, then dimension 3 in a second pass."
+struct Slab <: Decomposition end
+
+# High-level interface
+
+## TODO: Add optimized 1D algorithm
+
+## 1D out-of-place: gather the vector, transform it locally, wrap the result again
+AbstractFFTs.fft(A::DVector) = DVector(AbstractFFTs.fft(collect(A)))
+AbstractFFTs.ifft(A::DVector) = DVector(AbstractFFTs.ifft(collect(A)))
+
+## 1D in-place: round-trip through a local buffer and write back into `DA`
+function AbstractFFTs.fft!(DA::DVector{T}) where T
+    A = Vector{T}(undef, length(DA))
+    copyto!(A, DA)
+    AbstractFFTs.fft!(A)
+    copyto!(DA, A)
+    return DA
+end
+function AbstractFFTs.ifft!(DA::DVector{T}) where T
+    A = Vector{T}(undef, length(DA))
+    copyto!(A, DA)
+    AbstractFFTs.ifft!(A)
+    copyto!(DA, A)
+    return DA
+end
+
+## 2D out-of-place
+function AbstractFFTs.fft(DA::DMatrix, dims=(1, 2))
+    DB = similar(DA)
+    _fft!(DB, DA, dims)
+    return DB
+end
+function AbstractFFTs.ifft(DA::DMatrix, dims=(1, 2))
+    DB = similar(DA)
+    _ifft!(DB, DA, dims)
+    return DB
+end
+
+## 2D in-place
+function AbstractFFTs.fft!(DA::DMatrix{T}, dims=(1, 2)) where T
+    _fft!(DA, DA, dims)
+    return DA
+end
+function AbstractFFTs.ifft!(DA::DMatrix{T}, dims=(1, 2)) where T
+    _ifft!(DA, DA, dims)
+    return DA
+end
+
+## 3D out-of-place (`decomp` is a `Decomposition` or one of `:pencil`/`:slab`)
+function AbstractFFTs.fft(DA::DArray{T,3}, dims=(1, 2, 3); decomp::Union{Decomposition,Symbol}=Pencil()) where T
+    DB = similar(DA)
+    _fft!(DB, DA, dims; decomp=_to_decomp(decomp))
+    return DB
+end
+function AbstractFFTs.ifft(DA::DArray{T,3}, dims=(1, 2, 3); decomp::Union{Decomposition,Symbol}=Pencil()) where T
+    DB = similar(DA)
+    _ifft!(DB, DA, dims; decomp=_to_decomp(decomp))
+    return DB
+end
+
+## 3D in-place
+function AbstractFFTs.fft!(DA::DArray{T,3}, dims=(1, 2, 3); decomp::Union{Decomposition,Symbol}=Pencil()) where T
+    _fft!(DA, DA, dims; decomp=_to_decomp(decomp))
+    return DA
+end
+function AbstractFFTs.ifft!(DA::DArray{T,3}, dims=(1, 2, 3); decomp::Union{Decomposition,Symbol}=Pencil()) where T
+    _ifft!(DA, DA, dims; decomp=_to_decomp(decomp))
+    return DA
+end
+
+# Mid-level interface
+
+_to_decomp(decomp::Decomposition) = decomp
+function _to_decomp(decomp::Symbol)
+    decomp === :pencil && return Pencil()
+    decomp === :slab && return Slab()
+    throw(ArgumentError("Unknown decomposition type: $decomp\nSupported types: :pencil, :slab"))
+end
+
+# Per-dimension chunk length used when a dimension is split across processors.
+# Each extent is divided separately so non-square/non-cubic arrays keep the
+# transformed dimension whole, and the result is clamped to 1 so that having more
+# processors than elements never produces an empty block.
+function _blocklens(A::DArray)
+    np = length(Dagger.compatible_processors())
+    return map(n -> max(1, div(n, np)), size(A))
+end
+
+# A dimension is only whole within each chunk during the pass that transforms it,
+# so `dims` must match one of the `allowed` orders; any other order would transform
+# chunk-local slices and silently produce a wrong result.
+function _check_dims(dims, allowed...)
+    requested = Tuple(dims)
+    any(==(requested), allowed) && return nothing
+    throw(ArgumentError("Unsupported transform dimensions: $dims\nSupported: $(join(allowed, ", "))"))
+end
+
+## 2D
+function _fft!(output::DMatrix{T}, input::DMatrix{T}, dims=(1, 2)) where T
+    return _transform2d!(__fft!, output, input, dims)
+end
+function _ifft!(output::DMatrix{T}, input::DMatrix{T}, dims=(1, 2)) where T
+    return _transform2d!(__ifft!, output, input, dims)
+end
+function _transform2d!(kernel!, output::DMatrix{T}, input::DMatrix{T}, dims) where T
+    _check_dims(dims, (1, 2))
+    n1, n2 = size(input)
+    b1, b2 = _blocklens(input)
+    A = zeros(Blocks(n1, b2), T, size(input)) # dimension 1 whole in every chunk
+    copyto!(A, input)
+    B = zeros(Blocks(b1, n2), T, size(input)) # dimension 2 whole in every chunk
+    kernel!(A, B, dims)
+    copyto!(output, B)
+    return output
+end
+
+## 3D
+function _fft!(output::DArray{T,3}, input::DArray{T,3}, dims=(1, 2, 3); decomp::Decomposition=Pencil()) where T
+    return _fft3d!(decomp, output, input, dims)
+end
+function _ifft!(output::DArray{T,3}, input::DArray{T,3}, dims=(1, 2, 3); decomp::Decomposition=Pencil()) where T
+    return _ifft3d!(decomp, output, input, dims)
+end
+
+function _fft3d!(decomp::Pencil, output::DArray{T,3}, input::DArray{T,3}, dims) where T
+    _check_dims(dims, (1, 2, 3))
+    n1, n2, n3 = size(input)
+    b1, b2, b3 = _blocklens(input)
+    A = zeros(Blocks(n1, b2, b3), T, size(input))
+    B = zeros(Blocks(b1, n2, b3), T, size(input))
+    C = zeros(Blocks(b1, b2, n3), T, size(input))
+    copyto!(A, input)
+    __fft!(decomp, A, B, C, dims)
+    copyto!(output, C)
+    return output
+end
+function _fft3d!(decomp::Slab, output::DArray{T,3}, input::DArray{T,3}, dims) where T
+    # Dimensions 1 and 2 are transformed together, so their order is irrelevant
+    _check_dims(dims, (1, 2, 3), (2, 1, 3))
+    n1, n2, n3 = size(input)
+    b1, b2, b3 = _blocklens(input)
+    A = zeros(Blocks(n1, n2, b3), T, size(input))
+    B = zeros(Blocks(b1, b2, n3), T, size(input))
+    copyto!(A, input)
+    __fft!(decomp, A, B, dims)
+    copyto!(output, B)
+    return output
+end
+function _fft3d!(decomp::Decomposition, output, input, dims)
+    throw(ArgumentError("Unknown decomposition type: $decomp"))
+end
+
+# The inverse transforms run the same passes in the opposite order
+function _ifft3d!(decomp::Pencil, output::DArray{T,3}, input::DArray{T,3}, dims) where T
+    _check_dims(dims, (1, 2, 3))
+    n1, n2, n3 = size(input)
+    b1, b2, b3 = _blocklens(input)
+    A = zeros(Blocks(b1, b2, n3), T, size(input))
+    B = zeros(Blocks(b1, n2, b3), T, size(input))
+    C = zeros(Blocks(n1, b2, b3), T, size(input))
+    copyto!(A, input)
+    __ifft!(decomp, A, B, C, dims)
+    copyto!(output, C)
+    return output
+end
+function _ifft3d!(decomp::Slab, output::DArray{T,3}, input::DArray{T,3}, dims) where T
+    _check_dims(dims, (1, 2, 3), (2, 1, 3))
+    n1, n2, n3 = size(input)
+    b1, b2, b3 = _blocklens(input)
+    A = zeros(Blocks(b1, b2, n3), T, size(input))
+    B = zeros(Blocks(n1, n2, b3), T, size(input))
+    copyto!(A, input)
+    __ifft!(decomp, A, B, dims)
+    copyto!(output, B)
+    return output
+end
+function _ifft3d!(decomp::Decomposition, output, input, dims)
+    throw(ArgumentError("Unknown decomposition type: $decomp"))
+end
+
+# Internal functions
+
+struct FFT! end
+struct RFFT! end
+struct IRFFT! end
+struct IFFT! end
+
+# Build the in-place plan matching `transform`; only `FFT!` and `IFFT!` are implemented
+plan_transform(::FFT!, A, dims; kwargs...) = AbstractFFTs.plan_fft!(A, dims; kwargs...)
+plan_transform(::IFFT!, A, dims; kwargs...) = AbstractFFTs.plan_ifft!(A, dims; kwargs...)
+function plan_transform(transform, A, dims; kwargs...)
+    throw(ArgumentError("Unknown transform type: $transform"))
+end
+function apply_fft!(out_part, in_part, transform, dim)
+    plan = plan_transform(transform, in_part, dim)
+    LinearAlgebra.mul!(out_part, plan, in_part)
+    return
+end
+apply_fft!(inout_part, transform, dim) = apply_fft!(inout_part, inout_part, transform, dim)
+
+# Transform every chunk of `D` in place along `region`, spawning one task per chunk
+# (named `label`) inside a single datadeps region.
+function _transform_chunks!(D::DArray, transform, region, label::String)
+    Dagger.spawn_datadeps() do
+        for part in D.chunks
+            Dagger.@spawn name=label apply_fft!(InOut(part), transform, region)
+        end
+    end
+    return
+end
+
+## 2D
+function __fft!(A::DMatrix{T}, B::DMatrix{T}, dims) where T
+    _transform_chunks!(A, FFT!(), dims[1], "apply_fft!(dim 1)")
+    copyto!(B, A)
+    _transform_chunks!(B, FFT!(), dims[2], "apply_fft!(dim 2)")
+    return
+end
+function __ifft!(A::DMatrix{T}, B::DMatrix{T}, dims) where T
+    _transform_chunks!(A, IFFT!(), dims[1], "apply_ifft!(dim 1)")
+    copyto!(B, A)
+    _transform_chunks!(B, IFFT!(), dims[2], "apply_ifft!(dim 2)")
+    return
+end
+
+## 3D
+function __fft!(::Pencil, A::DArray{T,3}, B::DArray{T,3}, C::DArray{T,3}, dims) where T
+    _transform_chunks!(A, FFT!(), dims[1], "apply_fft!(dim 1)")
+    copyto!(B, A)
+    _transform_chunks!(B, FFT!(), dims[2], "apply_fft!(dim 2)")
+    copyto!(C, B)
+    _transform_chunks!(C, FFT!(), dims[3], "apply_fft!(dim 3)")
+    return
+end
+function __fft!(::Slab, A::DArray{T,3}, B::DArray{T,3}, dims) where T
+    _transform_chunks!(A, FFT!(), (dims[1], dims[2]), "apply_fft!(dim 1&2)")
+    copyto!(B, A)
+    _transform_chunks!(B, FFT!(), dims[3], "apply_fft!(dim 3)")
+    return
+end
+function __ifft!(::Pencil, A::DArray{T,3}, B::DArray{T,3}, C::DArray{T,3}, dims) where T
+    _transform_chunks!(A, IFFT!(), dims[3], "apply_ifft!(dim 3)")
+    copyto!(B, A)
+    _transform_chunks!(B, IFFT!(), dims[2], "apply_ifft!(dim 2)")
+    copyto!(C, B)
+    _transform_chunks!(C, IFFT!(), dims[1], "apply_ifft!(dim 1)")
+    return
+end
+function __ifft!(::Slab, A::DArray{T,3}, B::DArray{T,3}, dims) where T
+    _transform_chunks!(A, IFFT!(), dims[3], "apply_ifft!(dim 3)")
+    copyto!(B, A)
+    _transform_chunks!(B, IFFT!(), (dims[1], dims[2]), "apply_ifft!(dim 1&2)")
+    return
+end
+
+end # module AbstractFFTsExt
diff --git a/test/Project.toml b/test/Project.toml
index 10b64f326..72daddfb3 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -5,6 +5,7 @@ CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
diff --git a/test/array/fft.jl b/test/array/fft.jl
new file mode 100644
index 000000000..54a4afe2a
--- /dev/null
+++ b/test/array/fft.jl
@@ -0,0 +1,157 @@
+using FFTW
+
+@testset "FFT" begin
+    @testset for T in (ComplexF64, ComplexF32)
+        @testset "1D" begin
+            # Out-of-place
+            A = rand(T, 100)
+            DA = DArray(A)
+            B = fft(A)
+            DB = fft(DA)
+            @test DB isa DVector{T}
+            @test B ≈ collect(DB)
+
+            # In-place
+            A = rand(T, 100)
+            DA = DArray(A)
+            fft!(A)
+            fft!(DA)
+            @test A ≈ collect(DA)
+        end
+
+        @testset "2D" begin
+            # Out-of-place
+            A = rand(T, 100, 100)
+            DA = DArray(A)
+            B = fft(A)
+            DB = fft(DA)
+            @test DB isa DMatrix{T}
+            @test B ≈ collect(DB)
+
+            # In-place
+            A = rand(T, 100, 100)
+            DA = DArray(A)
+            fft!(A)
+            fft!(DA)
+            @test A ≈ collect(DA)
+
+            # Non-square input (each pass must still see its dimension whole)
+            A = rand(T, 64, 96)
+            DA = DArray(A)
+            @test fft(A) ≈ collect(fft(DA))
+
+            # Only the default dimension order is supported
+            @test_throws ArgumentError fft(DA, (2, 1))
+        end
+
+        @testset "3D" begin
+            # Out-of-place (Pencil)
+            A = rand(T, 100, 100, 100)
+            DA = DArray(A)
+            B = fft(A)
+            DB = fft(DA; decomp=:pencil)
+            @test DB isa DArray{T, 3}
+            @test B ≈ collect(DB)
+
+            # Out-of-place (Slab)
+            A = rand(T, 100, 100, 100)
+            DA = DArray(A)
+            B = fft(A)
+            DB = fft(DA; decomp=:slab)
+            @test DB isa DArray{T, 3}
+            @test B ≈ collect(DB)
+
+            # In-place (Pencil)
+            A = rand(T, 100, 100, 100)
+            DA = DArray(A)
+            fft!(A)
+            fft!(DA; decomp=:pencil)
+            @test A ≈ collect(DA)
+
+            # In-place (Slab)
+            A = rand(T, 100, 100, 100)
+            DA = DArray(A)
+            fft!(A)
+            fft!(DA; decomp=:slab)
+            @test A ≈ collect(DA)
+        end
+    end
+end
+
+@testset "IFFT" begin
+    @testset for T in (ComplexF64, ComplexF32)
+        @testset "1D" begin
+            # Out-of-place
+            A = rand(T, 100)
+            DA = DArray(A)
+            B = ifft(A)
+            DB = ifft(DA)
+            @test DB isa DVector{T}
+            @test B ≈ collect(DB)
+
+            # In-place
+            A = rand(T, 100)
+            DA = DArray(A)
+            ifft!(A)
+            ifft!(DA)
+            @test A ≈ collect(DA)
+        end
+
+        @testset "2D" begin
+            # Out-of-place
+            A = rand(T, 100, 100)
+            DA = DArray(A)
+            B = ifft(A)
+            DB = ifft(DA)
+            @test DB isa DMatrix{T}
+            @test B ≈ collect(DB)
+
+            # In-place
+            A = rand(T, 100, 100)
+            DA = DArray(A)
+            ifft!(A)
+            ifft!(DA)
+            @test A ≈ collect(DA)
+
+            # Non-square input (each pass must still see its dimension whole)
+            A = rand(T, 64, 96)
+            DA = DArray(A)
+            @test ifft(A) ≈ collect(ifft(DA))
+
+            # Only the default dimension order is supported
+            @test_throws ArgumentError ifft(DA, (2, 1))
+        end
+
+        @testset "3D" begin
+            # Out-of-place (Pencil)
+            A = rand(T, 100, 100, 100)
+            DA = DArray(A)
+            B = ifft(A)
+            DB = ifft(DA; decomp=:pencil)
+            @test DB isa DArray{T, 3}
+            @test B ≈ collect(DB)
+
+            # Out-of-place (Slab)
+            A = rand(T, 100, 100, 100)
+            DA = DArray(A)
+            B = ifft(A)
+            DB = ifft(DA; decomp=:slab)
+            @test DB isa DArray{T, 3}
+            @test B ≈ collect(DB)
+
+            # In-place (Pencil)
+            A = rand(T, 100, 100, 100)
+            DA = DArray(A)
+            ifft!(A)
+            ifft!(DA; decomp=:pencil)
+            @test A ≈ collect(DA)
+
+            # In-place (Slab)
+            A = rand(T, 100, 100, 100)
+            DA = DArray(A)
+            ifft!(A)
+            ifft!(DA; decomp=:slab)
+            @test A ≈ collect(DA)
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index a7b7a890a..8ef8cb3a0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -32,6 +32,7 @@ tests = [
     ("Array - LinearAlgebra - LU", "array/linalg/lu.jl"),
     ("Array - Random", "array/random.jl"),
     ("Array - Stencils", "array/stencil.jl"),
+    ("Array - FFT", "array/fft.jl"),
     ("GPU", "gpu.jl"),
     ("Caching", "cache.jl"),
     ("Disk Caching", "diskcaching.jl"),