diff --git a/src/algos.jl b/src/algos.jl index 13fb626..233c992 100644 --- a/src/algos.jl +++ b/src/algos.jl @@ -187,40 +187,56 @@ Power of 4 FFT, in place """ function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U} - plusi = sign(imag(w))*im minusi = -sign(imag(w))*im - if N == 4 - out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in] + in[start_in + 3*stride_in] - out[start_out + stride_out] = in[start_in] + in[start_in + stride_in]*plusi - in[start_in + 2*stride_in] + in[start_in + 3*stride_in]*minusi - out[start_out + 2*stride_out] = in[start_in] - in[start_in + stride_in] + in[start_in + 2*stride_in] - in[start_in + 3*stride_in] - out[start_out + 3*stride_out] = in[start_in] + in[start_in + stride_in]*minusi - in[start_in + 2*stride_in] + in[start_in + 3*stride_in]*plusi + @inbounds if N == 4 + xee = in[start_in] + xoe = in[start_in + stride_in] + xeo = in[start_in + 2*stride_in] + xoo = in[start_in + 3*stride_in] + xee_p_xeo = xee + xeo + xee_m_xeo = xee - xeo + xoe_p_xoo = xoe + xoo + xoe_m_xoo = -(xoe - xoo)*minusi + out[start_out] = xee_p_xeo + xoe_p_xoo + out[start_out + stride_out] = xee_m_xeo + xoe_m_xoo + out[start_out + 2*stride_out] = xee_p_xeo - xoe_p_xoo + out[start_out + 3*stride_out] = xee_m_xeo - xoe_m_xoo return end m = N ÷ 4 - @muladd fft_pow4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w^4) - @muladd fft_pow4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w^4) - @muladd fft_pow4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w^4) - @muladd fft_pow4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w^4) - w1 = w w2 = w*w1 w3 = w*w2 - wk1 = wk2 = wk3 = one(T) + w4 = w*w3 + + fft_pow4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w4) + fft_pow4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w4) + fft_pow4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w4) + fft_pow4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w4) + + wkoe = wkeo = wkoo = one(T) @inbounds for k in 0:m-1 - @muladd k0 = start_out + k*stride_out - @muladd k1 = start_out + (k+m)*stride_out - @muladd k2 = start_out + (k+2*m)*stride_out - @muladd k3 = start_out + (k+3*m)*stride_out - y_k0, y_k1, y_k2, y_k3 = out[k0], out[k1], out[k2], out[k3] - @muladd out[k0] = (y_k0 + y_k2*wk2) + (y_k1*wk1 + y_k3*wk3) - @muladd out[k1] = (y_k0 - y_k2*wk2) + (y_k1*wk1 - y_k3*wk3) * plusi - @muladd out[k2] = (y_k0 + y_k2*wk2) - (y_k1*wk1 + y_k3*wk3) - @muladd out[k3] = (y_k0 - y_k2*wk2) + (y_k1*wk1 - y_k3*wk3) * minusi - wk1 *= w1 - wk2 *= w2 - wk3 *= w3 + kee = start_out + k * stride_out + koe = start_out + (k + m) * stride_out + keo = start_out + (k + 2 * m) * stride_out + koo = start_out + (k + 3 * m) * stride_out + y_kee, y_koe, y_keo, y_koo = out[kee], out[koe], out[keo], out[koo] + ỹ_keo = y_keo*wkeo + ỹ_koe = y_koe*wkoe + ỹ_koo = y_koo*wkoo + y_kee_p_y_keo = y_kee + ỹ_keo + y_kee_m_y_keo = y_kee - ỹ_keo + ỹ_koe_p_ỹ_koo = ỹ_koe + ỹ_koo + ỹ_koe_m_ỹ_koo = -(ỹ_koe - ỹ_koo) * minusi + out[kee] = y_kee_p_y_keo + ỹ_koe_p_ỹ_koo + out[koe] = y_kee_m_y_keo + ỹ_koe_m_ỹ_koo + out[keo] = y_kee_p_y_keo - ỹ_koe_p_ỹ_koo + out[koo] = y_kee_m_y_keo - ỹ_koe_m_ỹ_koo + wkoe *= w1 + wkeo *= w2 + wkoo *= w3 end end