Skip to content

Commit 61b139b

Browse files
authored
feat: generalize Ops.fft (#1358)
* feat: generalize Ops.fft * test: fft and irfft with real inputs * fix: generalize for unsorted dims
1 parent 780370e commit 61b139b

File tree

3 files changed

+78
-80
lines changed

3 files changed

+78
-80
lines changed

ext/ReactantAbstractFFTsExt.jl

Lines changed: 46 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,70 @@
11
module ReactantAbstractFFTsExt
22

33
using AbstractFFTs: AbstractFFTs
4-
using Reactant: Reactant, MLIR, Ops, TracedRArray
4+
using Reactant: Reactant, MLIR, Ops, AnyTracedRArray, TracedRArray, TracedUtils
55

6-
function check_contiguous_innermost_dims(dims, N)
7-
@assert sort([dims...]) == [dims...] "un-sorted dims are not supported"
8-
all(i -> dims[i] == dims[i - 1] + 1, 2:(length(dims))) || return false
9-
dims[1] != 1 && return false
10-
return true
6+
function __permutation_to_move_dims_to_end(dims, N::Integer)
7+
perm = [i for i in 1:N if i Set(dims)]
8+
append!(perm, reverse(dims))
9+
return perm
1110
end
1211

13-
function compute_correct_pdims(x::AbstractArray, dims::Int)
14-
counter = 0
15-
return ntuple(ndims(x)) do i
16-
i == 1 && return dims
17-
counter += 1
18-
return counter
19-
end
20-
end
12+
__is_valid_stablehlo_fft_dims(dim::Integer, N::Integer) = dim == N
2113

22-
function compute_correct_pdims(x::AbstractArray, dims)
23-
counter = 0
24-
return ntuple(ndims(x)) do i
25-
i length(dims) && return dims[i]
26-
counter += 1
27-
while counter dims
28-
counter += 1
29-
end
30-
return counter
31-
end
14+
function __is_valid_stablehlo_fft_dims(dims, N::Integer)
15+
return collect(dims) == collect(N:-1:(N - length(dims) + 1))
3216
end
3317

3418
for op in (:rfft, :fft, :ifft)
35-
mode = uppercase(string(op))
36-
@eval function AbstractFFTs.$(op)(x::TracedRArray, dims)
37-
@assert maximum(dims) ndims(x) "dims out of range"
38-
if dims isa Integer
39-
if dims != 1
40-
pdims = compute_correct_pdims(x, dims)
41-
return permutedims(
42-
AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims)
43-
)
44-
end
45-
return generalized_fft(x, $(mode), nothing, length(dims))
46-
end
47-
if !check_contiguous_innermost_dims(dims, ndims(x))
48-
pdims = compute_correct_pdims(x, dims)
49-
return permutedims(
50-
AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims)
19+
@eval function AbstractFFTs.$(op)(x::AnyTracedRArray, dims)
20+
@assert maximum(dims) <= ndims(x) "Invalid dimensions for fft: $(dims)"
21+
22+
fft_lengths = Int64[size(x, dim) for dim in reverse(dims)]
23+
if __is_valid_stablehlo_fft_dims(dims, ndims(x))
24+
return Ops.fft(
25+
TracedUtils.materialize_traced_array(x);
26+
type=$(uppercase(string(op))),
27+
length=fft_lengths,
5128
)
5229
end
53-
return generalized_fft(x, $(mode), nothing, length(dims))
30+
perm = __permutation_to_move_dims_to_end(dims, ndims(x))
31+
return permutedims(
32+
Ops.fft(
33+
TracedUtils.materialize_traced_array(permutedims(x, perm));
34+
type=$(uppercase(string(op))),
35+
length=fft_lengths,
36+
),
37+
invperm(perm),
38+
)
5439
end
5540
end
5641

5742
for op in (:irfft,)
5843
mode = uppercase(string(op))
59-
@eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims)
60-
@assert maximum(dims) ndims(x) "dims out of range"
61-
if dims isa Integer
62-
if dims != 1
63-
pdims = compute_correct_pdims(x, dims)
64-
return permutedims(
65-
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims)
66-
)
67-
end
68-
return generalized_fft(x, $(mode), d, length(dims))
69-
end
70-
if !check_contiguous_innermost_dims(dims, ndims(x))
71-
pdims = compute_correct_pdims(x, dims)
72-
return permutedims(
73-
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims)
44+
45+
@eval function AbstractFFTs.$(op)(x::AnyTracedRArray, d::Integer, dims)
46+
@assert maximum(dims) <= ndims(x) "Invalid dimensions for irfft: $(dims)"
47+
48+
fft_lengths = vcat(Int64[size(x, dim) for dim in reverse(dims[2:end])], d)
49+
50+
if __is_valid_stablehlo_fft_dims(dims, ndims(x))
51+
return Ops.fft(
52+
TracedUtils.materialize_traced_array(x);
53+
type=$(uppercase(string(op))),
54+
length=fft_lengths,
7455
)
7556
end
76-
return generalized_fft(x, $(mode), d, length(dims))
77-
end
78-
end
7957

80-
function generalized_fft(x::TracedRArray{T,N}, mode::String, d, first_n::Int) where {T,N}
81-
if d === nothing
82-
@assert mode ("RFFT", "FFT", "IFFT")
83-
fft_length = [size(x, i) for i in 1:first_n]
84-
else
85-
@assert mode == "IRFFT"
86-
fft_length = [i == 1 ? d : size(x, i) for i in 1:first_n]
58+
perm = __permutation_to_move_dims_to_end(dims, ndims(x))
59+
return permutedims(
60+
Ops.fft(
61+
TracedUtils.materialize_traced_array(permutedims(x, perm));
62+
type=$(uppercase(string(op))),
63+
length=fft_lengths,
64+
),
65+
invperm(perm),
66+
)
8767
end
88-
89-
x = permutedims(x, reverse(1:N))
90-
reverse!(fft_length)
91-
x = Ops.fft(x; type=mode, length=fft_length)
92-
return permutedims(x, reverse(1:N))
9368
end
9469

9570
end

src/Ops.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,7 @@ end
666666
return TracedRNumber{U}((), res)
667667
end
668668

669+
# TODO: See https://github.com/jax-ml/jax/blob/6c18aa8a468e35b8c11b101dceaa43d05b497177/jax/_src/numpy/fft.py#L106
669670
@noinline function fft(
670671
x::TracedRArray{T,N};
671672
type::String,
@@ -675,8 +676,10 @@ end
675676
@assert 1 <= Base.length(length) <= 3 "fft only supports up to rank 3"
676677

677678
if type ("FFT", "IFFT")
678-
@assert T <: Complex
679-
Tout = T
679+
if !(T <: Complex)
680+
x = Ops.complex(x, fill(T(0), size(x); location); location)
681+
end
682+
Tout = Base.complex(T)
680683
rsize = size(x)
681684
elseif type == "RFFT"
682685
@assert T <: Real
@@ -686,8 +689,10 @@ end
686689
Tuple(rsize)
687690
end
688691
elseif type == "IRFFT"
689-
@assert T <: Complex
690-
Tout = Base.Base.real(T)
692+
if !(T <: Complex)
693+
x = Ops.complex(x, fill(T(0), size(x); location); location)
694+
end
695+
Tout = Base.real(T)
691696
rsize = let rsize = collect(Int64, size(x))
692697
rsize[(end - Base.length(length) + 1):end] = length
693698
Tuple(rsize)

test/integration/fft.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
using FFTW, Reactant
1+
using FFTW, Reactant, Test
22

33
@testset "fft" begin
44
x = rand(ComplexF32, 2, 2, 3, 4)
55
x_ra = Reactant.to_rarray(x)
66

7-
@test_throws AssertionError @jit(fft(x_ra))
7+
@test_throws AssertionError @jit(fft(x_ra)) # TODO: support this
88

99
x = rand(ComplexF32, 2, 3, 4)
1010
x_ra = Reactant.to_rarray(x)
@@ -15,18 +15,27 @@ using FFTW, Reactant
1515
@test @jit(fft(x_ra, (2, 3))) fft(x, (2, 3))
1616
@test @jit(fft(x_ra, (1, 3))) fft(x, (1, 3))
1717

18-
@test_throws AssertionError @jit(fft(x_ra, (3, 2)))
18+
@test @jit(fft(x_ra, (3, 2))) fft(x, (3, 2))
1919
@test_throws AssertionError @jit(fft(x_ra, (1, 4)))
2020

2121
y_ra = @jit(fft(x_ra))
2222
@test @jit(ifft(y_ra)) x
23+
24+
@testset "fft real input" begin
25+
x = rand(Float32, 2, 3, 4)
26+
x_ra = Reactant.to_rarray(x)
27+
28+
@test @jit(fft(x_ra)) fft(x)
29+
@test @jit(fft(x_ra, (1, 2))) fft(x, (1, 2))
30+
@test @jit(fft(x_ra, (1, 2, 3))) fft(x, (1, 2, 3))
31+
end
2332
end
2433

2534
@testset "rfft" begin
2635
x = rand(2, 2, 3, 4)
2736
x_ra = Reactant.to_rarray(x)
2837

29-
@test_throws AssertionError @jit(rfft(x_ra))
38+
@test_throws AssertionError @jit(rfft(x_ra)) # TODO: support this
3039

3140
x = rand(2, 3, 4)
3241
x_ra = Reactant.to_rarray(x)
@@ -37,10 +46,19 @@ end
3746
@test @jit(rfft(x_ra, (2, 3))) rfft(x, (2, 3))
3847
@test @jit(rfft(x_ra, (1, 3))) rfft(x, (1, 3))
3948

40-
@test_throws AssertionError @jit(rfft(x_ra, (3, 2)))
49+
@test @jit(rfft(x_ra, (3, 2))) rfft(x, (3, 2))
4150
@test_throws AssertionError @jit(rfft(x_ra, (1, 4)))
4251

4352
y_ra = @jit(rfft(x_ra))
4453
@test @jit(irfft(y_ra, 2)) x
4554
@test @jit(irfft(y_ra, 3)) irfft(rfft(x), 3)
55+
56+
@testset "irfft real input" begin
57+
y_ra_real = @jit(real(y_ra))
58+
y_real = Array(y_ra_real)
59+
60+
@test @jit(rfft(x_ra)) rfft(x)
61+
@test @jit(rfft(x_ra, (1, 2))) rfft(x, (1, 2))
62+
@test @jit(rfft(x_ra, (1, 2, 3))) rfft(x, (1, 2, 3))
63+
end
4664
end

0 commit comments

Comments
 (0)