|
1 | 1 | module ReactantAbstractFFTsExt
|
2 | 2 |
|
3 | 3 | using AbstractFFTs: AbstractFFTs
|
4 |
| -using Reactant: Reactant, MLIR, Ops, TracedRArray |
| 4 | +using Reactant: Reactant, MLIR, Ops, AnyTracedRArray, TracedRArray, TracedUtils |
5 | 5 |
|
6 |
| -function check_contiguous_innermost_dims(dims, N) |
7 |
| - @assert sort([dims...]) == [dims...] "un-sorted dims are not supported" |
8 |
| - all(i -> dims[i] == dims[i - 1] + 1, 2:(length(dims))) || return false |
9 |
| - dims[1] != 1 && return false |
10 |
| - return true |
| 6 | +function __permutation_to_move_dims_to_end(dims, N::Integer) |
| 7 | + perm = [i for i in 1:N if i ∉ Set(dims)] |
| 8 | + append!(perm, reverse(dims)) |
| 9 | + return perm |
11 | 10 | end
|
12 | 11 |
|
13 |
| -function compute_correct_pdims(x::AbstractArray, dims::Int) |
14 |
| - counter = 0 |
15 |
| - return ntuple(ndims(x)) do i |
16 |
| - i == 1 && return dims |
17 |
| - counter += 1 |
18 |
| - return counter |
19 |
| - end |
20 |
| -end |
| 12 | +__is_valid_stablehlo_fft_dims(dim::Integer, N::Integer) = dim == N |
21 | 13 |
|
22 |
| -function compute_correct_pdims(x::AbstractArray, dims) |
23 |
| - counter = 0 |
24 |
| - return ntuple(ndims(x)) do i |
25 |
| - i ≤ length(dims) && return dims[i] |
26 |
| - counter += 1 |
27 |
| - while counter ∈ dims |
28 |
| - counter += 1 |
29 |
| - end |
30 |
| - return counter |
31 |
| - end |
| 14 | +function __is_valid_stablehlo_fft_dims(dims, N::Integer) |
| 15 | + return collect(dims) == collect(N:-1:(N - length(dims) + 1)) |
32 | 16 | end
|
33 | 17 |
|
34 | 18 | for op in (:rfft, :fft, :ifft)
|
35 |
| - mode = uppercase(string(op)) |
36 |
| - @eval function AbstractFFTs.$(op)(x::TracedRArray, dims) |
37 |
| - @assert maximum(dims) ≤ ndims(x) "dims out of range" |
38 |
| - if dims isa Integer |
39 |
| - if dims != 1 |
40 |
| - pdims = compute_correct_pdims(x, dims) |
41 |
| - return permutedims( |
42 |
| - AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims) |
43 |
| - ) |
44 |
| - end |
45 |
| - return generalized_fft(x, $(mode), nothing, length(dims)) |
46 |
| - end |
47 |
| - if !check_contiguous_innermost_dims(dims, ndims(x)) |
48 |
| - pdims = compute_correct_pdims(x, dims) |
49 |
| - return permutedims( |
50 |
| - AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims) |
| 19 | + @eval function AbstractFFTs.$(op)(x::AnyTracedRArray, dims) |
| 20 | + @assert maximum(dims) <= ndims(x) "Invalid dimensions for fft: $(dims)" |
| 21 | + |
| 22 | + fft_lengths = Int64[size(x, dim) for dim in reverse(dims)] |
| 23 | + if __is_valid_stablehlo_fft_dims(dims, ndims(x)) |
| 24 | + return Ops.fft( |
| 25 | + TracedUtils.materialize_traced_array(x); |
| 26 | + type=$(uppercase(string(op))), |
| 27 | + length=fft_lengths, |
51 | 28 | )
|
52 | 29 | end
|
53 |
| - return generalized_fft(x, $(mode), nothing, length(dims)) |
| 30 | + perm = __permutation_to_move_dims_to_end(dims, ndims(x)) |
| 31 | + return permutedims( |
| 32 | + Ops.fft( |
| 33 | + TracedUtils.materialize_traced_array(permutedims(x, perm)); |
| 34 | + type=$(uppercase(string(op))), |
| 35 | + length=fft_lengths, |
| 36 | + ), |
| 37 | + invperm(perm), |
| 38 | + ) |
54 | 39 | end
|
55 | 40 | end
|
56 | 41 |
|
57 | 42 | for op in (:irfft,)
|
58 | 43 | mode = uppercase(string(op))
|
59 |
| - @eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims) |
60 |
| - @assert maximum(dims) ≤ ndims(x) "dims out of range" |
61 |
| - if dims isa Integer |
62 |
| - if dims != 1 |
63 |
| - pdims = compute_correct_pdims(x, dims) |
64 |
| - return permutedims( |
65 |
| - AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims) |
66 |
| - ) |
67 |
| - end |
68 |
| - return generalized_fft(x, $(mode), d, length(dims)) |
69 |
| - end |
70 |
| - if !check_contiguous_innermost_dims(dims, ndims(x)) |
71 |
| - pdims = compute_correct_pdims(x, dims) |
72 |
| - return permutedims( |
73 |
| - AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims) |
| 44 | + |
| 45 | + @eval function AbstractFFTs.$(op)(x::AnyTracedRArray, d::Integer, dims) |
| 46 | + @assert maximum(dims) <= ndims(x) "Invalid dimensions for irfft: $(dims)" |
| 47 | + |
| 48 | + fft_lengths = vcat(Int64[size(x, dim) for dim in reverse(dims[2:end])], d) |
| 49 | + |
| 50 | + if __is_valid_stablehlo_fft_dims(dims, ndims(x)) |
| 51 | + return Ops.fft( |
| 52 | + TracedUtils.materialize_traced_array(x); |
| 53 | + type=$(uppercase(string(op))), |
| 54 | + length=fft_lengths, |
74 | 55 | )
|
75 | 56 | end
|
76 |
| - return generalized_fft(x, $(mode), d, length(dims)) |
77 |
| - end |
78 |
| -end |
79 | 57 |
|
80 |
| -function generalized_fft(x::TracedRArray{T,N}, mode::String, d, first_n::Int) where {T,N} |
81 |
| - if d === nothing |
82 |
| - @assert mode ∈ ("RFFT", "FFT", "IFFT") |
83 |
| - fft_length = [size(x, i) for i in 1:first_n] |
84 |
| - else |
85 |
| - @assert mode == "IRFFT" |
86 |
| - fft_length = [i == 1 ? d : size(x, i) for i in 1:first_n] |
| 58 | + perm = __permutation_to_move_dims_to_end(dims, ndims(x)) |
| 59 | + return permutedims( |
| 60 | + Ops.fft( |
| 61 | + TracedUtils.materialize_traced_array(permutedims(x, perm)); |
| 62 | + type=$(uppercase(string(op))), |
| 63 | + length=fft_lengths, |
| 64 | + ), |
| 65 | + invperm(perm), |
| 66 | + ) |
87 | 67 | end
|
88 |
| - |
89 |
| - x = permutedims(x, reverse(1:N)) |
90 |
| - reverse!(fft_length) |
91 |
| - x = Ops.fft(x; type=mode, length=fft_length) |
92 |
| - return permutedims(x, reverse(1:N)) |
93 | 68 | end
|
94 | 69 |
|
95 | 70 | end
|
0 commit comments