Open
Labels: cuda libraries (Stuff about CUDA library wrappers.)
Description
In `lib/cufft/fft.jl`, the `Base.:(*)` operator is overloaded:
Lines 359 to 377 in 9528a33:

```julia
function Base.:(*)(p::CuFFTPlan{T,S,K,false}, x::DenseCuArray{S1,M}) where {T,S,K,S1,M}
    if T<:Real
        # Out-of-place complex-to-real FFT will always overwrite input x.
        # We copy the input x in an auxiliary buffer.
        z = p.buffer
        copyto!(z, x)
    else
        if S1 != S
            # Convert to the expected input type.
            z = copy1(S, x)
        else
            z = x
        end
    end
    assert_applicable(p, z)
    y = CuArray{T,M}(undef, p.output_size)
    unsafe_execute_trailing!(p, z, y)
    y
end
```
The type parametrization is probably too broad: `x::DenseCuArray{S1,M} where {S1,M}` matches a dense `CuArray` of any element type, so `S1` can be a type that CuFFT does not support.
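To make the point concrete without a GPU, the sketch below uses plain `DenseArray` in place of `DenseCuArray`, and a hypothetical `MockDual` struct in place of `ForwardDiff.Dual`. It checks which arrays fall under the unconstrained `where {S1,M}` signature and under an assumed narrower one. The `CufftEltype` union is an illustrative assumption, not the element-type set actually defined in CUDA.jl.

```julia
# Hypothetical stand-in for ForwardDiff.Dual (keeps the example dependency-free).
struct MockDual{V<:Real} <: Real
    value::V
    partial::V
end

# The argument type used in fft.jl, with DenseArray standing in for DenseCuArray:
# S1 is unconstrained, so any element type is accepted.
const BroadArg = DenseArray{S1,M} where {S1,M}

# Assumed narrower alternative: only element types cuFFT can transform.
const CufftEltype = Union{Float32,Float64,ComplexF32,ComplexF64}
const NarrowArg = DenseArray{S1,M} where {S1<:CufftEltype,M}

duals = [MockDual(1.0, 0.0), MockDual(2.0, 1.0)]
cplx = rand(ComplexF32, 4)

println(duals isa BroadArg)   # true: arrays of duals match the broad signature
println(duals isa NarrowArg)  # false: excluded once S1 is constrained
println(cplx isa NarrowArg)   # true: supported element types still match
```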
In DFTK, we overload `Base.:(*)` for abstract FFT plans and abstract arrays of `ForwardDiff.Dual` (here). Unfortunately, this fails on NVIDIA GPUs because CuFFT's method takes precedence, even though `ForwardDiff.Dual` is not a type that CuFFT should support.
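A minimal sketch of how a constrained signature would let a downstream overload like DFTK's be selected. All types here are hypothetical stand-ins (`MockPlan` for `AbstractFFTs.Plan`, `MockCuPlan` for `CuFFTPlan`, `MockDual` for `ForwardDiff.Dual`), and the `CufftEltype` union is an assumption rather than CUDA.jl's actual definition. Because the library method no longer matches arrays of duals, dispatch no longer depends on which of two overlapping methods Julia considers more specific.

```julia
abstract type MockPlan end          # stand-in for AbstractFFTs.Plan
struct MockCuPlan <: MockPlan end   # stand-in for CuFFTPlan

# Hypothetical stand-in for ForwardDiff.Dual.
struct MockDual{V<:Real} <: Real
    value::V
    partial::V
end

# Assumed set of element types the library can transform.
const CufftEltype = Union{Float32,Float64,ComplexF32,ComplexF64}

# Library-side method with S1 constrained (the suggested narrowing).
Base.:(*)(p::MockCuPlan, x::DenseArray{S1,M}) where {S1<:CufftEltype,M} = :library

# Downstream (DFTK-like) method: any plan, arrays of dual numbers.
Base.:(*)(p::MockPlan, x::AbstractArray{<:MockDual}) = :downstream

p = MockCuPlan()
println(p * rand(ComplexF64, 8))    # library
println(p * [MockDual(1.0, 0.0)])   # downstream
```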
@mfherbst for reference.