-
Notifications
You must be signed in to change notification settings - Fork 31
oneMKL DFT support #515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
oneMKL DFT support #515
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/lib/mkl/fft.jl b/lib/mkl/fft.jl
index 5f5614b..9d38e11 100644
--- a/lib/mkl/fft.jl
+++ b/lib/mkl/fft.jl
@@ -22,34 +22,34 @@ using ..Support
# Allow implicit conversion of SYCL queue object to raw handle when storing/passing
Base.convert(::Type{syclQueue_t}, q::SYCL.syclQueue) = Base.unsafe_convert(syclQueue_t, q)
-abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end
+abstract type MKLFFTPlan{T, K, inplace} <: AbstractFFTs.Plan{T} end
-Base.eltype(::MKLFFTPlan{T}) where T = T
-is_inplace(::MKLFFTPlan{<:Any,<:Any,inplace}) where inplace = inplace
+Base.eltype(::MKLFFTPlan{T}) where {T} = T
+is_inplace(::MKLFFTPlan{<:Any, <:Any, inplace}) where {inplace} = inplace
# Forward / inverse flags
const MKLFFT_FORWARD = true
const MKLFFT_INVERSE = false
-mutable struct cMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct cMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
handle::onemklDftDescriptor_t
queue::syclQueue_t
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
realdomain::Bool
- region::NTuple{R,Int}
+ region::NTuple{R, Int}
buffer::B
pinv::Any
end
# Real transforms use separate struct (mirroring AMDGPU style) for buffer staging
-mutable struct rMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct rMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
handle::onemklDftDescriptor_t
queue::syclQueue_t
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
xtype::Symbol
- region::NTuple{R,Int}
+ region::NTuple{R, Int}
buffer::B
pinv::Any
end
@@ -57,40 +57,44 @@ end
# Inverse plan constructors (derive from existing plan)
function normalization_factor(sz, region)
# AbstractFFTs expects inverse to scale by 1/prod(lengths along region)
- prod(ntuple(i-> sz[region[i]], length(region)))
+ return prod(ntuple(i -> sz[region[i]], length(region)))
end
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:brfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :brfft, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:rfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :rfft, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function Base.show(io::IO, p::MKLFFTPlan{T,K,inplace}) where {T,K,inplace}
+function Base.show(io::IO, p::MKLFFTPlan{T, K, inplace}) where {T, K, inplace}
print(io, inplace ? "oneMKL FFT in-place " : "oneMKL FFT ", K ? "forward" : "inverse", " plan for ")
- if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
- print(io, " oneArray of ", T)
+ if isempty(p.sz)
+ print(io, "0-dimensional")
+ else
+ print(io, join(p.sz, "×"))
+ end
+ return print(io, " oneArray of ", T)
end
# Plan constructors
-function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
- prec = T<:Float64 || T<:ComplexF64 ? ONEMKL_DFT_PRECISION_DOUBLE : ONEMKL_DFT_PRECISION_SINGLE
+function _create_descriptor(sz::NTuple{N, Int}, T::Type, complex::Bool) where {N}
+ prec = T <: Float64 || T <: ComplexF64 ? ONEMKL_DFT_PRECISION_DOUBLE : ONEMKL_DFT_PRECISION_SINGLE
dom = complex ? ONEMKL_DFT_DOMAIN_COMPLEX : ONEMKL_DFT_DOMAIN_REAL
desc_ref = Ref{onemklDftDescriptor_t}()
# Create descriptor for the full array dimensions
@@ -109,8 +113,8 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
end
# Complex plans
-function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -119,20 +123,20 @@ function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,Co
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_NOT_INPLACE)
if N > 1
# Column-major strides: stride along dimension i is product of sizes of previous dims
- strides = Vector{Int64}(undef, N+1); strides[1]=0
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0
prod = 1
@inbounds for i in 1:N
- strides[i+1] = prod
- prod *= size(X,i)
+ strides[i + 1] = prod
+ prod *= size(X, i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
- return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
-function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -140,87 +144,87 @@ function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,C
desc, q = _create_descriptor(size(X), T, true)
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_NOT_INPLACE)
if N > 1
- strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
@inbounds for i in 1:N
- strides[i+1]=prod; prod*=size(X,i)
+ strides[i + 1] = prod; prod *= size(X, i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
- return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
# In-place (provide separate methods)
-function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
end
- desc,q = _create_descriptor(size(X),T,true)
+ desc, q = _create_descriptor(size(X), T, true)
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_INPLACE)
if N > 1
- strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
@inbounds for i in 1:N
- strides[i+1]=prod; prod*=size(X,i)
+ strides[i + 1] = prod; prod *= size(X, i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
- cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_FORWARD, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
-function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
end
- desc,q = _create_descriptor(size(X),T,true)
+ desc, q = _create_descriptor(size(X), T, true)
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_INPLACE)
if N > 1
- strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
@inbounds for i in 1:N
- strides[i+1]=prod; prod*=size(X,i)
+ strides[i + 1] = prod; prod *= size(X, i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
- cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_INVERSE, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
# Real input methods - convert to complex like FFTW does
-function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_fft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
CT = Complex{T}
# Create a complex plan by converting the real array to complex
X_complex = oneAPI.oneArray{CT}(undef, size(X))
- plan_fft(X_complex, region)
+ return plan_fft(X_complex, region)
end
-function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_bfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
CT = Complex{T}
# Create a complex plan by converting the real array to complex
X_complex = oneAPI.oneArray{CT}(undef, size(X))
- plan_bfft(X_complex, region)
+ return plan_bfft(X_complex, region)
end
-function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_fft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
error("In-place FFT not supported for real input arrays. Use plan_fft instead.")
end
-function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_bfft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
error("In-place FFT not supported for real input arrays. Use plan_bfft instead.")
end
# Real forward (out-of-place) - supports multi-dimensional transforms
-function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_rfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
# Convert region to tuple if it's a range
if isa(region, AbstractUnitRange)
region = tuple(region...)
end
- R = length(region); reg = NTuple{R,Int}(region)
+ R = length(region); reg = NTuple{R, Int}(region)
# For single dimension transforms, use the optimized oneMKL real FFT
if R == 1 && reg[1] == 1
@@ -234,12 +238,12 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
end
# Single-dimension real FFT using oneMKL (optimized path)
-function _plan_rfft_1d(X::oneAPI.oneArray{T,N}, reg::NTuple{1,Int}) where {T<:Union{Float32,Float64},N}
+function _plan_rfft_1d(X::oneAPI.oneArray{T, N}, reg::NTuple{1, Int}) where {T <: Union{Float32, Float64}, N}
# Create 1D descriptor for the transform dimension
- desc,q = _create_descriptor((size(X, reg[1]),), T, false)
+ desc, q = _create_descriptor((size(X, reg[1]),), T, false)
xdims = size(X)
# output along first dim becomes N/2+1
- ydims = Base.setindex(xdims, div(xdims[1],2)+1, 1)
+ ydims = Base.setindex(xdims, div(xdims[1], 2) + 1, 1)
buffer = oneAPI.oneArray{Complex{T}}(undef, ydims)
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_NOT_INPLACE)
@@ -255,18 +259,18 @@ function _plan_rfft_1d(X::oneAPI.oneArray{T,N}, reg::NTuple{1,Int}) where {T<:Un
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
R = length(reg)
- rMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:rfft,reg,buffer,nothing)
+ return rMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :rfft, reg, buffer, nothing)
end
# Multi-dimensional real FFT using complex FFT approach
-struct ComplexBasedRealFFTPlan{T,N,R} <: MKLFFTPlan{T,MKLFFT_FORWARD,false}
- complex_plan::cMKLFFTPlan{Complex{T},MKLFFT_FORWARD,false,N,R,Nothing}
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
- region::NTuple{R,Int}
+struct ComplexBasedRealFFTPlan{T, N, R} <: MKLFFTPlan{T, MKLFFT_FORWARD, false}
+ complex_plan::cMKLFFTPlan{Complex{T}, MKLFFT_FORWARD, false, N, R, Nothing}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
+ region::NTuple{R, Int}
end
-function _plan_rfft_nd(X::oneAPI.oneArray{T,N}, reg::NTuple{R,Int}) where {T<:Union{Float32,Float64},N,R}
+function _plan_rfft_nd(X::oneAPI.oneArray{T, N}, reg::NTuple{R, Int}) where {T <: Union{Float32, Float64}, N, R}
# Create complex version for planning
X_complex = oneAPI.oneArray{Complex{T}}(undef, size(X))
complex_plan = plan_fft(X_complex, reg)
@@ -281,18 +285,22 @@ function _plan_rfft_nd(X::oneAPI.oneArray{T,N}, reg::NTuple{R,Int}) where {T<:Un
end
end
- ComplexBasedRealFFTPlan{T,N,R}(complex_plan, xdims, ydims, reg)
+ return ComplexBasedRealFFTPlan{T, N, R}(complex_plan, xdims, ydims, reg)
end
# Show method for complex-based plan
function Base.show(io::IO, p::ComplexBasedRealFFTPlan{T}) where {T}
print(io, "oneMKL FFT forward plan for ")
- if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
- print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
+ if isempty(p.sz)
+ print(io, "0-dimensional")
+ else
+ print(io, join(p.sz, "×"))
+ end
+ return print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
end
# Execution for complex-based real FFT plan
-function Base.:*(p::ComplexBasedRealFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) where {T,N,R}
+function Base.:*(p::ComplexBasedRealFFTPlan{T, N, R}, X::oneAPI.oneArray{T}) where {T, N, R}
# Convert to complex
X_complex = Complex{T}.(X)
@@ -316,14 +324,13 @@ function Base.:*(p::ComplexBasedRealFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) where
end
-
# Real inverse (complex->real) requires complex input shape - supports multi-dimensional transforms
-function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union{ComplexF32,ComplexF64},N}
+function plan_brfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T <: Union{ComplexF32, ComplexF64}, N}
# Convert region to tuple if it's a range
if isa(region, AbstractUnitRange)
region = tuple(region...)
end
- R = length(region); reg = NTuple{R,Int}(region)
+ R = length(region); reg = NTuple{R, Int}(region)
# For single dimension transforms along first dim, use optimized oneMKL path
if R == 1 && reg[1] == 1
@@ -335,13 +342,13 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
end
# Single-dimension real inverse FFT using oneMKL (optimized path)
-function _plan_brfft_1d(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{1,Int}) where {T<:Union{ComplexF32,ComplexF64},N}
+function _plan_brfft_1d(X::oneAPI.oneArray{T, N}, d::Integer, reg::NTuple{1, Int}) where {T <: Union{ComplexF32, ComplexF64}, N}
# Extract underlying real type R from Complex{R}
@assert T <: Complex
RT = T.parameters[1]
# Create 1D descriptor for the transform dimension
- desc,q = _create_descriptor((d,), RT, false)
+ desc, q = _create_descriptor((d,), RT, false)
xdims = size(X)
ydims = Base.setindex(xdims, d, 1)
buffer = oneAPI.oneArray{T}(undef, xdims) # copy for safety
@@ -355,19 +362,19 @@ function _plan_brfft_1d(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{1,Int})
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
R = length(reg)
- rMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:brfft,reg,buffer,nothing)
+ return rMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :brfft, reg, buffer, nothing)
end
# Multi-dimensional real inverse FFT using complex FFT approach
-struct ComplexBasedRealIFFTPlan{T,N,R} <: MKLFFTPlan{T,MKLFFT_INVERSE,false}
- complex_plan::cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
- region::NTuple{R,Int}
+struct ComplexBasedRealIFFTPlan{T, N, R} <: MKLFFTPlan{T, MKLFFT_INVERSE, false}
+ complex_plan::cMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, Nothing}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
+ region::NTuple{R, Int}
d::Int # Original size of the reduced dimension
end
-function _plan_brfft_nd(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{R,Int}) where {T<:Union{ComplexF32,ComplexF64},N,R}
+function _plan_brfft_nd(X::oneAPI.oneArray{T, N}, d::Integer, reg::NTuple{R, Int}) where {T <: Union{ComplexF32, ComplexF64}, N, R}
# Calculate the full complex array size (before real FFT reduction)
xdims = size(X)
full_complex_dims = ntuple(N) do i
@@ -382,18 +389,22 @@ function _plan_brfft_nd(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{R,Int})
X_complex_full = oneAPI.oneArray{T}(undef, full_complex_dims)
complex_plan = plan_bfft(X_complex_full, reg)
- ComplexBasedRealIFFTPlan{T,N,R}(complex_plan, xdims, full_complex_dims, reg, d)
+ return ComplexBasedRealIFFTPlan{T, N, R}(complex_plan, xdims, full_complex_dims, reg, d)
end
# Show method for complex-based inverse plan
function Base.show(io::IO, p::ComplexBasedRealIFFTPlan{T}) where {T}
print(io, "oneMKL FFT inverse plan for ")
- if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
- print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
+ if isempty(p.sz)
+ print(io, "0-dimensional")
+ else
+ print(io, join(p.sz, "×"))
+ end
+ return print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
end
# Execution for complex-based real inverse FFT plan
-function Base.:*(p::ComplexBasedRealIFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) where {T,N,R}
+function Base.:*(p::ComplexBasedRealIFFTPlan{T, N, R}, X::oneAPI.oneArray{T}) where {T, N, R}
# Reconstruct full complex array by exploiting conjugate symmetry
# This is a simplified approach - for full accuracy, we'd need to properly
# reconstruct the conjugate symmetric part
@@ -435,7 +446,7 @@ function Base.:*(p::ComplexBasedRealIFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) wher
end
# Inverse plan for complex-based real FFT plans
-function plan_inv(p::ComplexBasedRealFFTPlan{T,N,R}) where {T,N,R}
+function plan_inv(p::ComplexBasedRealFFTPlan{T, N, R}) where {T, N, R}
# For real FFT inverse, we need plan_brfft functionality
# The first dimension in the region should be the one that was reduced
first_dim = minimum(p.region)
@@ -443,18 +454,17 @@ function plan_inv(p::ComplexBasedRealFFTPlan{T,N,R}) where {T,N,R}
# Create inverse plan using our new multi-dimensional brfft
brfft_plan = _plan_brfft_nd(oneAPI.oneArray{Complex{T}}(undef, p.osz), d, p.region)
- ScaledPlan(brfft_plan, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(brfft_plan, 1 / normalization_factor(p.sz, p.region))
end
# Inverse plan for complex-based real inverse FFT plans
-function plan_inv(p::ComplexBasedRealIFFTPlan{T,N,R}) where {T,N,R}
+function plan_inv(p::ComplexBasedRealIFFTPlan{T, N, R}) where {T, N, R}
# Create forward plan
forward_plan = _plan_rfft_nd(oneAPI.oneArray{real(T)}(undef, p.osz), p.region)
- ScaledPlan(forward_plan, 1/normalization_factor(p.osz, p.region))
+ return ScaledPlan(forward_plan, 1 / normalization_factor(p.osz, p.region))
end
-
# Convenience no-region methods use all dimensions in order
plan_fft(X::oneAPI.oneArray) = plan_fft(X, ntuple(identity, ndims(X)))
plan_bfft(X::oneAPI.oneArray) = plan_bfft(X, ntuple(identity, ndims(X)))
@@ -467,111 +477,119 @@ plan_brfft(X::oneAPI.oneArray, d::Integer) = plan_brfft(X, d, ntuple(identity, n
const plan_ifft = plan_bfft
const plan_ifft! = plan_bfft!
# plan_irfft should be normalized, unlike plan_brfft
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T,N} = begin
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T, N} = begin
p = plan_brfft(X, d, region)
- ScaledPlan(p, 1/normalization_factor(p.sz, p.region))
+ ScaledPlan(p, 1 / normalization_factor(p.sz, p.region))
end
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer) where {T,N} = plan_irfft(X, d, (1,))
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer) where {T, N} = plan_irfft(X, d, (1,))
# Inversion
Base.inv(p::MKLFFTPlan) = plan_inv(p)
# High-level wrappers operating like CPU FFTW versions.
-function fft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
- (plan_fft(X) * X)
+function fft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
+ return (plan_fft(X) * X)
end
-function ifft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
p = plan_bfft(X)
# Apply normalization for ifft (unlike bfft which is unnormalized)
scaling = one(T) / normalization_factor(size(X), ntuple(identity, ndims(X)))
- scaling * (p * X)
+ return scaling * (p * X)
end
-function fft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function fft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
(plan_fft!(X) * X; X)
end
-function ifft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
p = plan_bfft!(X)
# Apply normalization for ifft! (unlike bfft! which is unnormalized)
scaling = one(T) / normalization_factor(size(X), ntuple(identity, ndims(X)))
p * X
X .*= scaling
- X
+ return X
end
-function rfft(X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
- (plan_rfft(X) * X)
+function rfft(X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+ return (plan_rfft(X) * X)
end
-function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T<:Union{ComplexF32,ComplexF64}}
+function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T <: Union{ComplexF32, ComplexF64}}
# Use the normalized plan_irfft instead of unnormalized plan_brfft
- (plan_irfft(X, d) * X)
+ return (plan_irfft(X, d) * X)
end
# Execution helpers
-_rawptr(a::oneAPI.oneArray{T}) where T = reinterpret(Ptr{Cvoid}, pointer(a))
+_rawptr(a::oneAPI.oneArray{T}) where {T} = reinterpret(Ptr{Cvoid}, pointer(a))
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_FORWARD,true}, X::oneAPI.oneArray{T}) where T
- st = onemklDftComputeForward(p.handle, _rawptr(X)); st==0 || error("forward FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_FORWARD, true}, X::oneAPI.oneArray{T}) where {T}
+ st = onemklDftComputeForward(p.handle, _rawptr(X)); st == 0 || error("forward FFT failed ($st)")
+ return X
end
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_INVERSE,true}, X::oneAPI.oneArray{T}) where T
- st = onemklDftComputeBackward(p.handle, _rawptr(X)); st==0 || error("inverse FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_INVERSE, true}, X::oneAPI.oneArray{T}) where {T}
+ st = onemklDftComputeBackward(p.handle, _rawptr(X)); st == 0 || error("inverse FFT failed ($st)")
+ return X
end
-function _exec!(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T,K}
- st = (K==MKLFFT_FORWARD ? onemklDftComputeForwardOutOfPlace : onemklDftComputeBackwardOutOfPlace)(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("FFT failed ($st)"); Y
+function _exec!(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T, K}
+ st = (K == MKLFFT_FORWARD ? onemklDftComputeForwardOutOfPlace : onemklDftComputeBackwardOutOfPlace)(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("FFT failed ($st)")
+ return Y
end
# Real forward
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where T
- st = onemklDftComputeForwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("rfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where {T}
+ st = onemklDftComputeForwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("rfft failed ($st)")
+ return Y
end
# Real inverse (complex -> real)
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R,T<:Complex{R}}
- st = onemklDftComputeBackwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("brfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R, T <: Complex{R}}
+ st = onemklDftComputeBackwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("brfft failed ($st)")
+ return Y
end
# Public API similar to AMDGPU
-function Base.:*(p::cMKLFFTPlan{T,K,true}, X::oneAPI.oneArray{T}) where {T,K}
- _exec!(p,X)
+function Base.:*(p::cMKLFFTPlan{T, K, true}, X::oneAPI.oneArray{T}) where {T, K}
+ return _exec!(p, X)
end
-function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
- Y = oneAPI.oneArray{T}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+ Y = oneAPI.oneArray{T}(undef, p.osz)
+ return _exec!(p, X, Y)
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
- _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+ return _exec!(p, X, Y)
end
# Real forward
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
- Y = oneAPI.oneArray{Complex{T}}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+ Y = oneAPI.oneArray{Complex{T}}(undef, p.osz)
+ return _exec!(p, X, Y)
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
- _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+ return _exec!(p, X, Y)
end
# Real inverse
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
- Y = oneAPI.oneArray{R}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+ Y = oneAPI.oneArray{R}(undef, p.osz)
+ return _exec!(p, X, Y)
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
- _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+ return _exec!(p, X, Y)
end
# Support for applying complex plans to real arrays (convert real to complex first)
-function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{R}) where {T,K,R<:Union{Float32,Float64}}
+function Base.:*(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{R}) where {T, K, R <: Union{Float32, Float64}}
# Only allow if T is the complex version of R
if T != Complex{R}
error("Type mismatch: plan expects $(T) but got $(R)")
end
# Convert real input to complex
X_complex = complex.(X)
- p * X_complex
+ return p * X_complex
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{R}) where {T,K,R<:Union{Float32,Float64}}
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{R}) where {T, K, R <: Union{Float32, Float64}}
# Only allow if T is the complex version of R
if T != Complex{R}
error("Type mismatch: plan expects $(T) but got $(R)")
end
# Convert real input to complex
X_complex = complex.(X)
- _exec!(p, X_complex, Y)
+ return _exec!(p, X_complex, Y)
end
end # module FFT
diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl
index 06d8bee..0ea694b 100644
--- a/lib/support/liboneapi_support.jl
+++ b/lib/support/liboneapi_support.jl
@@ -7111,122 +7111,160 @@ mutable struct onemklDftDescriptor_st end
const onemklDftDescriptor_t = Ptr{onemklDftDescriptor_st}
function onemklDftCreate1D(desc, precision, domain, length)
- @ccall liboneapi_support.onemklDftCreate1D(desc::Ptr{onemklDftDescriptor_t},
- precision::onemklDftPrecision,
- domain::onemklDftDomain, length::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftCreate1D(
+ desc::Ptr{onemklDftDescriptor_t},
+ precision::onemklDftPrecision,
+ domain::onemklDftDomain, length::Int64
+ )::Cint
end
function onemklDftCreateND(desc, precision, domain, dim, lengths)
- @ccall liboneapi_support.onemklDftCreateND(desc::Ptr{onemklDftDescriptor_t},
- precision::onemklDftPrecision,
- domain::onemklDftDomain, dim::Int64,
- lengths::Ptr{Int64})::Cint
+ return @ccall liboneapi_support.onemklDftCreateND(
+ desc::Ptr{onemklDftDescriptor_t},
+ precision::onemklDftPrecision,
+ domain::onemklDftDomain, dim::Int64,
+ lengths::Ptr{Int64}
+ )::Cint
end
function onemklDftDestroy(desc)
- @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
+ return @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
end
function onemklDftCommit(desc, queue)
- @ccall liboneapi_support.onemklDftCommit(desc::onemklDftDescriptor_t,
- queue::syclQueue_t)::Cint
+ return @ccall liboneapi_support.onemklDftCommit(
+ desc::onemklDftDescriptor_t,
+ queue::syclQueue_t
+ )::Cint
end
function onemklDftSetValueInt64(desc, param, value)
- @ccall liboneapi_support.onemklDftSetValueInt64(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueInt64(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Int64
+ )::Cint
end
function onemklDftSetValueDouble(desc, param, value)
- @ccall liboneapi_support.onemklDftSetValueDouble(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Cdouble)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueDouble(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Cdouble
+ )::Cint
end
function onemklDftSetValueInt64Array(desc, param, values, n)
- @ccall liboneapi_support.onemklDftSetValueInt64Array(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- values::Ptr{Int64}, n::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueInt64Array(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ values::Ptr{Int64}, n::Int64
+ )::Cint
end
function onemklDftSetValueConfigValue(desc, param, value)
- @ccall liboneapi_support.onemklDftSetValueConfigValue(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::onemklDftConfigValue)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueConfigValue(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::onemklDftConfigValue
+ )::Cint
end
function onemklDftGetValueInt64(desc, param, value)
- @ccall liboneapi_support.onemklDftGetValueInt64(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Ptr{Int64})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueInt64(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Ptr{Int64}
+ )::Cint
end
function onemklDftGetValueDouble(desc, param, value)
- @ccall liboneapi_support.onemklDftGetValueDouble(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Ptr{Cdouble})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueDouble(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Ptr{Cdouble}
+ )::Cint
end
function onemklDftGetValueInt64Array(desc, param, values, n)
- @ccall liboneapi_support.onemklDftGetValueInt64Array(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- values::Ptr{Int64},
- n::Ptr{Int64})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueInt64Array(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ values::Ptr{Int64},
+ n::Ptr{Int64}
+ )::Cint
end
function onemklDftGetValueConfigValue(desc, param, value)
- @ccall liboneapi_support.onemklDftGetValueConfigValue(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Ptr{onemklDftConfigValue})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueConfigValue(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Ptr{onemklDftConfigValue}
+ )::Cint
end
function onemklDftComputeForward(desc, inout)
- @ccall liboneapi_support.onemklDftComputeForward(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForward(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeForwardOutOfPlace(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackward(desc, inout)
- @ccall liboneapi_support.onemklDftComputeBackward(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackward(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackwardOutOfPlace(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeForwardBuffer(desc, inout)
- @ccall liboneapi_support.onemklDftComputeForwardBuffer(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForwardBuffer(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeForwardOutOfPlaceBuffer(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackwardBuffer(desc, inout)
- @ccall liboneapi_support.onemklDftComputeBackwardBuffer(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackwardBuffer(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackwardOutOfPlaceBuffer(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftQueryParamIndices(out, n)
- @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
end
const ONEMKL_DFT_STATUS_SUCCESS = 0
diff --git a/res/wrap.jl b/res/wrap.jl
index 1d48315..2e9b29f 100644
--- a/res/wrap.jl
+++ b/res/wrap.jl
@@ -112,14 +112,14 @@ using oneAPI_Level_Zero_Headers_jll
function main()
wrap("ze", oneAPI_Level_Zero_Headers_jll.ze_api)
- wrap(
- "support",
- joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
- joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
- joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
- dependents=false,
- include_dirs=[dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
- )
+ return wrap(
+ "support",
+ joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
+ joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
+ joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
+ dependents = false,
+ include_dirs = [dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
+ )
end
isinteractive() || main()
diff --git a/test/fft.jl b/test/fft.jl
index 1b148df..ef81c21 100644
--- a/test/fft.jl
+++ b/test/fft.jl
@@ -7,39 +7,39 @@ using Random
Random.seed!(1234)
# Helper to move data to GPU
-gpu(A::AbstractArray{T}) where T = oneAPI.oneArray{T}(A)
+gpu(A::AbstractArray{T}) where {T} = oneAPI.oneArray{T}(A)
struct _Plan end
struct _FFT end
-const MYRTOL = 1e-5
-const MYATOL = 1e-8
+const MYRTOL = 1.0e-5
+const MYATOL = 1.0e-8
-function cmp(a,b; rtol=MYRTOL, atol=MYATOL)
- @test isapprox(Array(a), Array(b); rtol=rtol, atol=atol)
+function cmp(a, b; rtol = MYRTOL, atol = MYATOL)
+ return @test isapprox(Array(a), Array(b); rtol = rtol, atol = atol)
end
-function test_plan(::_Plan, plan, X::AbstractArray{T,N}) where {T,N}
+function test_plan(::_Plan, plan, X::AbstractArray{T, N}) where {T, N}
p = plan(X)
Y = p * X
return Y
end
-function test_plan(::_FFT, f, X::AbstractArray{T,N}) where {T,N}
+function test_plan(::_FFT, f, X::AbstractArray{T, N}) where {T, N}
Y = if f === AbstractFFTs.irfft || f === AbstractFFTs.brfft
- f(X, size(X, ndims(X))*2 - 2)
+ f(X, size(X, ndims(X)) * 2 - 2)
else
f(X)
end
return Y
end
-function test_plan(t, plan::Function, dim::Tuple, T::Type, iplan=nothing)
+function test_plan(t, plan::Function, dim::Tuple, T::Type, iplan = nothing)
X = rand(T, dim)
dX = gpu(X)
Y = test_plan(t, plan, X)
dY = test_plan(t, plan, dX)
cmp(dY, Y)
- if iplan !== nothing
+ return if iplan !== nothing
iX = test_plan(t, iplan, Y)
idX = test_plan(t, iplan, dY)
cmp(idX, iX)
@@ -47,36 +47,36 @@ function test_plan(t, plan::Function, dim::Tuple, T::Type, iplan=nothing)
end
@testset "FFT" begin
-@testset "$(length(dim))D" for dim in [(8,), (8,32), (8,32,64)]
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_ifft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_bfft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_ifft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_bfft)
- test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float32)
- test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF32, AbstractFFTs.plan_bfft!)
- # Not part of FFTW
- # test_plan(AbstractFFTs.plan_rfft!, Float32)
- test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.ifft)
- test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.bfft)
- if length(dim) == 1 # irfft/brfft only for 1D
- test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.irfft)
- test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.brfft)
- end
- if (ComplexF64 in eltypes) && (Float64 in eltypes)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_ifft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_bfft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_ifft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_bfft)
- test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float64)
- test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF64, AbstractFFTs.plan_bfft!)
+ @testset "$(length(dim))D" for dim in [(8,), (8, 32), (8, 32, 64)]
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_ifft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_bfft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_ifft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_bfft)
+ test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float32)
+ test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF32, AbstractFFTs.plan_bfft!)
# Not part of FFTW
- # test_plan(AbstractFFTs.plan_rfft!, Float64)
- test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.ifft)
- test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.bfft)
+ # test_plan(AbstractFFTs.plan_rfft!, Float32)
+ test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.ifft)
+ test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.bfft)
if length(dim) == 1 # irfft/brfft only for 1D
- test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.irfft)
- test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.brfft)
+ test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.irfft)
+ test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.brfft)
+ end
+ if (ComplexF64 in eltypes) && (Float64 in eltypes)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_ifft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_bfft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_ifft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_bfft)
+ test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float64)
+ test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF64, AbstractFFTs.plan_bfft!)
+ # Not part of FFTW
+ # test_plan(AbstractFFTs.plan_rfft!, Float64)
+ test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.ifft)
+ test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.bfft)
+ if length(dim) == 1 # irfft/brfft only for 1D
+ test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.irfft)
+ test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.brfft)
+ end
end
end
end
-end |
| } | ||
| *out = desc; | ||
| return 0; | ||
| } catch (...) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a catch-all.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #515 +/- ##
==========================================
- Coverage 81.73% 79.70% -2.04%
==========================================
Files 44 45 +1
Lines 2540 2818 +278
==========================================
+ Hits 2076 2246 +170
- Misses 464 572 +108 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
lib/mkl/fft.jl
Outdated
| ccall_create1d(desc_ref, prec::Int32, dom::Int32, length::Int64) = ccall((:onemklDftCreate1D, lib), Cint, (Ref{Ptr{Cvoid}}, Cint, Cint, Int64), desc_ref, prec, dom, length) | ||
| ccall_creatend(desc_ref, prec::Int32, dom::Int32, dim::Int64, lengths::Ptr{Int64}) = ccall((:onemklDftCreateND, lib), Cint, (Ref{Ptr{Cvoid}}, Cint, Cint, Int64, Ptr{Int64}), desc_ref, prec, dom, dim, lengths) | ||
| ccall_destroy(desc) = ccall((:onemklDftDestroy, lib), Cint, (Ptr{Cvoid},), desc) | ||
| ccall_commit(desc, q) = ccall((:onemklDftCommit, lib), Cint, (Ptr{Cvoid}, syclQueue_t), desc, q) | ||
| ccall_fwd(desc, ptr) = ccall((:onemklDftComputeForward, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), desc, ptr) | ||
| ccall_fwd_oop(desc, pin, pout) = ccall((:onemklDftComputeForwardOutOfPlace, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), desc, pin, pout) | ||
| ccall_bwd(desc, ptr) = ccall((:onemklDftComputeBackward, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), desc, ptr) | ||
| ccall_bwd_oop(desc, pin, pout) = ccall((:onemklDftComputeBackwardOutOfPlace, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), desc, pin, pout) | ||
| ccall_set_double(desc, param::Int32, value::Float64) = ccall((:onemklDftSetValueDouble, lib), Cint, (Ptr{Cvoid}, Cint, Float64), desc, param, value) | ||
| ccall_set_int(desc, param::Int32, value::Int64) = ccall((:onemklDftSetValueInt64, lib), Cint, (Ptr{Cvoid}, Cint, Int64), desc, param, value) | ||
| ccall_set_int64_array(desc, param::Int32, values::Vector{Int64}) = ccall((:onemklDftSetValueInt64Array, lib), Cint, (Ptr{Cvoid}, Cint, Ptr{Int64}, Int64), desc, param, pointer(values), length(values)) | ||
| ccall_set_cfg(desc, param::Int32, value::Int32) = ccall((:onemklDftSetValueConfigValue, lib), Cint, (Ptr{Cvoid}, Cint, Cint), desc, param, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@michel2323 Please use the wrappers generated by Clang.jl.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is generated with Clang. I'm not sure I understand.
Line 119 in 7822c44
| joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see now.
lib/mkl/fft.jl
Outdated
| R = length(region); reg = NTuple{R,Int}(region) | ||
| # Only support single dimension transforms for now | ||
| if R != 1 | ||
| error("Multi-dimensional real FFT not yet supported") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@michel2323 Do we know if it is feature not yet implemented by Intel or something to improve on our side?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get wrong values. Maybe I did something wrong, or the Intel library is wrong.
amontoison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@michel2323 Please address the two comments and it should be good for me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM (if all tests passed)!
It passes some tests. GPT 5 helped quite a bit here.