Skip to content

Commit 266c88f

Browse files
committed
Merge branch 'master' into adjoint
2 parents fe3b06a + b5109aa commit 266c88f

File tree

11 files changed

+129
-43
lines changed

11 files changed

+129
-43
lines changed

.github/dependabot.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
2+
version: 2
3+
updates:
4+
- package-ecosystem: "github-actions"
5+
directory: "/" # Location of package manifests
6+
schedule:
7+
interval: "weekly"

.github/workflows/CI.yml

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,32 +15,23 @@ jobs:
1515
version:
1616
- '1.0'
1717
- '1'
18-
# - 'nightly'
18+
- 'nightly'
1919
os:
2020
- ubuntu-latest
2121
- macOS-latest
2222
- windows-latest
2323
arch:
2424
- x64
2525
steps:
26-
- uses: actions/checkout@v2
26+
- uses: actions/checkout@v3
2727
- uses: julia-actions/setup-julia@v1
2828
with:
2929
version: ${{ matrix.version }}
3030
arch: ${{ matrix.arch }}
31-
- uses: actions/cache@v1
32-
env:
33-
cache-name: cache-artifacts
34-
with:
35-
path: ~/.julia/artifacts
36-
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
37-
restore-keys: |
38-
${{ runner.os }}-test-${{ env.cache-name }}-
39-
${{ runner.os }}-test-
40-
${{ runner.os }}-
31+
- uses: julia-actions/cache@v1
4132
- uses: julia-actions/julia-buildpkg@v1
4233
- uses: julia-actions/julia-runtest@v1
4334
- uses: julia-actions/julia-processcoverage@v1
44-
- uses: codecov/codecov-action@v1
35+
- uses: codecov/codecov-action@v3
4536
with:
4637
file: lcov.info

.github/workflows/Documenter.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ jobs:
1010
name: Documentation
1111
runs-on: ubuntu-latest
1212
steps:
13-
- uses: actions/checkout@v2
14-
- uses: julia-actions/julia-buildpkg@latest
15-
- uses: julia-actions/julia-docdeploy@latest
13+
- uses: actions/checkout@v3
14+
- uses: julia-actions/julia-buildpkg@v1
15+
- uses: julia-actions/julia-docdeploy@v1
1616
env:
1717
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
1818
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}

.github/workflows/IntegrationTest.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@ jobs:
1818
os: [ubuntu-latest]
1919
package:
2020
- {user: JuliaMath, repo: FFTW.jl}
21+
- {user: JuliaApproximation, repo: FastTransforms.jl}
2122
steps:
22-
- uses: actions/checkout@v2
23+
- uses: actions/checkout@v3
2324
- uses: julia-actions/setup-julia@v1
2425
with:
2526
version: ${{ matrix.julia-version }}
2627
arch: x64
27-
- uses: julia-actions/julia-buildpkg@latest
28+
- uses: julia-actions/julia-buildpkg@v1
2829
- name: Clone Downstream
29-
uses: actions/checkout@v2
30+
uses: actions/checkout@v3
3031
with:
3132
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
3233
path: downstream

.github/workflows/Invalidations.yml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
name: Invalidations
2+
3+
on:
4+
pull_request:
5+
6+
concurrency:
7+
# Skip intermediate builds: always.
8+
# Cancel intermediate builds: always.
9+
group: ${{ github.workflow }}-${{ github.ref }}
10+
cancel-in-progress: true
11+
12+
jobs:
13+
evaluate:
14+
# Only run on PRs to the default branch.
15+
# In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch
16+
if: github.base_ref == github.event.repository.default_branch
17+
runs-on: ubuntu-latest
18+
steps:
19+
- uses: julia-actions/setup-julia@v1
20+
with:
21+
version: '1'
22+
- uses: actions/checkout@v3
23+
- uses: julia-actions/julia-buildpkg@v1
24+
- uses: julia-actions/julia-invalidations@v1
25+
id: invs_pr
26+
27+
- uses: actions/checkout@v3
28+
with:
29+
ref: ${{ github.event.repository.default_branch }}
30+
- uses: julia-actions/julia-buildpkg@v1
31+
- uses: julia-actions/julia-invalidations@v1
32+
id: invs_default
33+
34+
- name: Report invalidation counts
35+
run: |
36+
echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY
37+
echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY
38+
- name: Check whether the number of invalidations increased
39+
if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total
40+
run: exit 1

Project.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,28 @@
11
name = "AbstractFFTs"
22
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3-
version = "1.2.1"
3+
version = "1.3.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

9+
[weakdeps]
10+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
11+
12+
[extensions]
13+
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
14+
915
[compat]
1016
ChainRulesCore = "1"
1117
julia = "^1.0"
1218

1319
[extras]
20+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1421
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
1522
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1623
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1724
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1825
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1926

2027
[targets]
21-
test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]
28+
test = ["ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
A general framework for fast Fourier transforms (FFTs) in Julia.
44

5-
[![Travis](https://travis-ci.org/JuliaMath/AbstractFFTs.jl.svg?branch=master)](https://travis-ci.org/JuliaMath/AbstractFFTs.jl)
6-
[![Coveralls](https://coveralls.io/repos/github/JuliaMath/AbstractFFTs.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaMath/AbstractFFTs.jl?branch=master)
5+
[![GHA](https://github.com/JuliaMath/AbstractFFTs.jl/workflows/CI/badge.svg)](https://github.com/JuliaMath/AbstractFFTs.jl/actions?query=workflow%3ACI+branch%3Amaster)
6+
[![Codecov](http://codecov.io/github/JuliaMath/AbstractFFTs.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaMath/AbstractFFTs.jl?branch=master)
77

88
Documentation:
99
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaMath.github.io/AbstractFFTs.jl/stable)
10-
[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaMath.github.io/AbstractFFTs.jl/latest)
10+
[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaMath.github.io/AbstractFFTs.jl/dev)
1111

1212
This package is mainly not intended to be used directly.
1313
Instead, developers of packages that implement FFTs (such as [FFTW.jl](https://github.com/JuliaMath/FFTW.jl) or [FastTransforms.jl](https://github.com/JuliaApproximation/FastTransforms.jl))
@@ -17,3 +17,4 @@ This allows multiple FFT packages to co-exist with the same underlying `fft(x)`
1717
## Developer information
1818

1919
To define a new FFT implementation in your own module, see [defining a new implementation](https://juliamath.github.io/AbstractFFTs.jl/stable/implementations/#Defining-a-new-implementation).
20+

src/chainrules.jl renamed to ext/AbstractFFTsChainRulesCoreExt.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
# ffts
1+
module AbstractFFTsChainRulesCoreExt
2+
3+
using AbstractFFTs
4+
import ChainRulesCore
5+
26
function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims)
37
y = fft(x, dims)
48
Δy = fft(Δx, dims)
@@ -33,7 +37,9 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
3337

3438
project_x = ChainRulesCore.ProjectTo(x)
3539
function rfft_pullback(ȳ)
36-
= project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
40+
ybar = ChainRulesCore.unthunk(ȳ)
41+
_scale = convert(typeof(ybar),scale)
42+
= project_x(brfft(ybar ./ _scale, d, dims))
3743
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
3844
end
3945
return y, rfft_pullback
@@ -46,7 +52,7 @@ function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dim
4652
end
4753
function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims)
4854
y = ifft(x, dims)
49-
invN = normalization(y, dims)
55+
invN = AbstractFFTs.normalization(y, dims)
5056
project_x = ChainRulesCore.ProjectTo(x)
5157
function ifft_pullback(ȳ)
5258
= project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims))
@@ -66,7 +72,7 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
6672
# compute scaling factors
6773
halfdim = first(dims)
6874
n = size(x, halfdim)
69-
invN = normalization(y, dims)
75+
invN = AbstractFFTs.normalization(y, dims)
7076
twoinvN = 2 * invN
7177
scale = reshape(
7278
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
@@ -75,7 +81,9 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
7581

7682
project_x = ChainRulesCore.ProjectTo(x)
7783
function irfft_pullback(ȳ)
78-
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
84+
ybar = ChainRulesCore.unthunk(ȳ)
85+
_scale = convert(typeof(ybar),scale)
86+
= project_x(_scale .* rfft(real.(ybar), dims))
7987
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
8088
end
8189
return y, irfft_pullback
@@ -152,12 +160,12 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
152160
end
153161

154162
# plans
155-
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::Plan, x::AbstractArray)
163+
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
156164
y = P * x
157165
Δy = P * Δx
158166
return y, Δy
159167
end
160-
function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
168+
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
161169
y = P * x
162170
project_x = ChainRulesCore.ProjectTo(x)
163171
Pt = P'
@@ -168,22 +176,25 @@ function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
168176
return y, mul_plan_pullback
169177
end
170178

171-
function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::ScaledPlan, x::AbstractArray)
179+
function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
172180
y = P * x
173181
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
174182
return y, Δy
175183
end
176-
function ChainRulesCore.rrule(::typeof(*), P::ScaledPlan, x::AbstractArray)
184+
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
177185
y = P * x
178186
Pt = P'
179187
scale = P.scale
180188
project_x = ChainRulesCore.ProjectTo(x)
181189
project_scale = ChainRulesCore.ProjectTo(scale)
182190
function mul_scaledplan_pullback(ȳ)
183191
= ChainRulesCore.@thunk(project_x(Pt * ȳ))
184-
scale_tangent = ChainRulesCore.@thunk(project_scale(dot(y, ȳ) / conj(scale)))
192+
scale_tangent = ChainRulesCore.@thunk(project_scale(AbstractFFTs.dot(y, ȳ) / conj(scale)))
185193
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
186194
return ChainRulesCore.NoTangent(), plan_tangent, x̄
187195
end
188196
return y, mul_scaledplan_pullback
189197
end
198+
199+
end # module
200+

src/AbstractFFTs.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
module AbstractFFTs
22

3-
import ChainRulesCore
4-
53
export fft, ifft, bfft, fft!, ifft!, bfft!,
64
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
75
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
86
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq
97

108
include("definitions.jl")
11-
include("chainrules.jl")
9+
10+
if !isdefined(Base, :get_extension)
11+
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
12+
end
1213

1314
end # module

src/definitions.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ _to1(::Tuple, x) = copy1(eltype(x), x)
6060
for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft)
6161
pf = Symbol("plan_", f)
6262
@eval begin
63-
$f(x::AbstractArray) = (y = to1(x); $pf(y) * y)
63+
$f(x::AbstractArray) = $f(x, 1:ndims(x))
6464
$f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y)
6565
$pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...))
6666
end
@@ -208,9 +208,9 @@ bfft!
208208
for f in (:fft, :bfft, :ifft)
209209
pf = Symbol("plan_", f)
210210
@eval begin
211-
$f(x::AbstractArray{<:Real}, region=1:ndims(x)) = $f(complexfloat(x), region)
211+
$f(x::AbstractArray{<:Real}, region) = $f(complexfloat(x), region)
212212
$pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...)
213-
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region=1:ndims(x)) = $f(complexfloat(x), region)
213+
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region)
214214
$pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...)
215215
end
216216
end
@@ -299,16 +299,16 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
299299
for f in (:brfft, :irfft)
300300
pf = Symbol("plan_", f)
301301
@eval begin
302-
$f(x::AbstractArray, d::Integer) = $pf(x, d) * x
302+
$f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x))
303303
$f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x
304304
$pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...)
305305
end
306306
end
307307

308308
for f in (:brfft, :irfft)
309309
@eval begin
310-
$f(x::AbstractArray{<:Real}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
311-
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
310+
$f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region)
311+
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region)
312312
end
313313
end
314314

0 commit comments

Comments
 (0)