CuFFT Base.:(*) overload types are too generic #3018

@abussy

In lib/cufft/fft.jl, the Base.:(*) operator is overloaded:

CUDA.jl/lib/cufft/fft.jl, lines 359 to 377 in 9528a33:

    function Base.:(*)(p::CuFFTPlan{T,S,K,false}, x::DenseCuArray{S1,M}) where {T,S,K,S1,M}
        if T<:Real
            # Out-of-place complex-to-real FFT will always overwrite input x.
            # We copy the input x in an auxiliary buffer.
            z = p.buffer
            copyto!(z, x)
        else
            if S1 != S
                # Convert to the expected input type.
                z = copy1(S, x)
            else
                z = x
            end
        end
        assert_applicable(p, z)
        y = CuArray{T,M}(undef, p.output_size)
        unsafe_execute_trailing!(p, z, y)
        y
    end

The type parametrization is probably too broad: x::DenseCuArray{S1,M} where {S1,M} matches a dense CuArray of any element type and any dimensionality. In particular, S1 can be a type that CuFFT does not support.

In DFTK, we overload the Base.:(*) operator for abstract FFT plans and abstract arrays of ForwardDiff.Dual. Unfortunately, this fails on NVIDIA GPUs because CuFFT's overload takes precedence, even though ForwardDiff.Dual is not an element type that CuFFT should support.
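To illustrate the problem without needing a GPU, here is a minimal sketch using toy stand-in types. All names in it (ToyPlan, ToyGPUPlan, ToyDual, ToyFFTNumber, apply_generic, apply_restricted) are hypothetical and do not exist in CUDA.jl, DFTK, or ForwardDiff; plain Array stands in for DenseCuArray. It shows that a method with an unconstrained element type parameter accepts an array of dual-like numbers, and that bounding the parameter would be one possible way to let a downstream fallback on the abstract plan type be selected instead. It is not meant as the actual fix for CUDA.jl.

```julia
# Toy stand-ins: none of these names exist in CUDA.jl, DFTK, or ForwardDiff.
abstract type ToyPlan end
struct ToyGPUPlan <: ToyPlan end      # plays the role of CuFFTPlan

struct ToyDual <: Real                # plays the role of ForwardDiff.Dual
    value::Float64
end

# Unconstrained element type, like `x::DenseCuArray{S1,M} where {S1,M}`.
apply_generic(::ToyGPUPlan, x::Array{S1,M}) where {S1,M} = :gpu_fft

# Hypothetical narrowing: only element types the backend is assumed to handle.
const ToyFFTNumber = Union{Float32,Float64,ComplexF32,ComplexF64}
apply_restricted(::ToyGPUPlan, x::Array{S1,M}) where {S1<:ToyFFTNumber,M} = :gpu_fft

# Downstream fallback for dual numbers, defined on the abstract plan type.
apply_restricted(::ToyPlan, x::AbstractArray{<:ToyDual}) = :dual_fallback

duals = [ToyDual(1.0), ToyDual(2.0)]

# The unconstrained method matches the dual array, although the backend
# it represents could not process that element type.
@assert applicable(apply_generic, ToyGPUPlan(), duals)

# With the bound on S1, the GPU method no longer matches dual arrays,
# so dispatch falls through to the downstream method.
@assert apply_restricted(ToyGPUPlan(), duals) == :dual_fallback

# Supported element types still reach the GPU method.
@assert apply_restricted(ToyGPUPlan(), [1.0, 2.0]) == :gpu_fft
```

In this sketch the bound on S1 is what keeps the two methods from competing for the same call: each array element type is matched by at most one of them.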

@mfherbst for reference.

Labels: cuda libraries (Stuff about CUDA library wrappers.)