Skip to content

Conversation

@michel2323
Copy link
Member

It passes some tests. GPT 5 helped quite a bit here.

@github-actions
Copy link
Contributor

github-actions bot commented Aug 19, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (`git runic master`) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/mkl/fft.jl b/lib/mkl/fft.jl
index 5f5614b..9d38e11 100644
--- a/lib/mkl/fft.jl
+++ b/lib/mkl/fft.jl
@@ -22,34 +22,34 @@ using ..Support
 # Allow implicit conversion of SYCL queue object to raw handle when storing/passing
 Base.convert(::Type{syclQueue_t}, q::SYCL.syclQueue) = Base.unsafe_convert(syclQueue_t, q)
 
-abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end
+abstract type MKLFFTPlan{T, K, inplace} <: AbstractFFTs.Plan{T} end
 
-Base.eltype(::MKLFFTPlan{T}) where T = T
-is_inplace(::MKLFFTPlan{<:Any,<:Any,inplace}) where inplace = inplace
+Base.eltype(::MKLFFTPlan{T}) where {T} = T
+is_inplace(::MKLFFTPlan{<:Any, <:Any, inplace}) where {inplace} = inplace
 
 # Forward / inverse flags
 const MKLFFT_FORWARD = true
 const MKLFFT_INVERSE = false
 
-mutable struct cMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct cMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
     handle::onemklDftDescriptor_t
     queue::syclQueue_t
-    sz::NTuple{N,Int}
-    osz::NTuple{N,Int}
+    sz::NTuple{N, Int}
+    osz::NTuple{N, Int}
     realdomain::Bool
-    region::NTuple{R,Int}
+    region::NTuple{R, Int}
     buffer::B
     pinv::Any
 end
 
 # Real transforms use separate struct (mirroring AMDGPU style) for buffer staging
-mutable struct rMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct rMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
     handle::onemklDftDescriptor_t
     queue::syclQueue_t
-    sz::NTuple{N,Int}
-    osz::NTuple{N,Int}
+    sz::NTuple{N, Int}
+    osz::NTuple{N, Int}
     xtype::Symbol
-    region::NTuple{R,Int}
+    region::NTuple{R, Int}
     buffer::B
     pinv::Any
 end
@@ -57,40 +57,44 @@ end
 # Inverse plan constructors (derive from existing plan)
 function normalization_factor(sz, region)
     # AbstractFFTs expects inverse to scale by 1/prod(lengths along region)
-    prod(ntuple(i-> sz[region[i]], length(region)))
+    return prod(ntuple(i -> sz[region[i]], length(region)))
 end
 
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
-    q = cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+    q = cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
     p.pinv = q
-    ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+    return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
 end
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
-    q = cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+    q = cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
     p.pinv = q
-    ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+    return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
 end
 
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
-    q = rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:brfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+    q = rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :brfft, p.region, p.buffer, p)
     p.pinv = q
-    ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+    return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
 end
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
-    q = rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:rfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+    q = rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :rfft, p.region, p.buffer, p)
     p.pinv = q
-    ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+    return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
 end
 
-function Base.show(io::IO, p::MKLFFTPlan{T,K,inplace}) where {T,K,inplace}
+function Base.show(io::IO, p::MKLFFTPlan{T, K, inplace}) where {T, K, inplace}
     print(io, inplace ? "oneMKL FFT in-place " : "oneMKL FFT ", K ? "forward" : "inverse", " plan for ")
-    if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
-    print(io, " oneArray of ", T)
+    if isempty(p.sz)
+        print(io, "0-dimensional")
+    else
+        print(io, join(p.sz, "×"))
+    end
+    return print(io, " oneArray of ", T)
 end
 
 # Plan constructors
-function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
-    prec = T<:Float64 || T<:ComplexF64 ? ONEMKL_DFT_PRECISION_DOUBLE : ONEMKL_DFT_PRECISION_SINGLE
+function _create_descriptor(sz::NTuple{N, Int}, T::Type, complex::Bool) where {N}
+    prec = T <: Float64 || T <: ComplexF64 ? ONEMKL_DFT_PRECISION_DOUBLE : ONEMKL_DFT_PRECISION_SINGLE
     dom = complex ? ONEMKL_DFT_DOMAIN_COMPLEX : ONEMKL_DFT_DOMAIN_REAL
     desc_ref = Ref{onemklDftDescriptor_t}()
     # Create descriptor for the full array dimensions
@@ -109,8 +113,8 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
 end
 
 # Complex plans
-function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
-    R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+    R = length(region); reg = NTuple{R, Int}(region)
     # For now, only support full transforms (all dimensions)
     if reg != ntuple(identity, N)
         error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -119,20 +123,20 @@ function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,Co
     onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_NOT_INPLACE)
     if N > 1
         # Column-major strides: stride along dimension i is product of sizes of previous dims
-        strides = Vector{Int64}(undef, N+1); strides[1]=0
+        strides = Vector{Int64}(undef, N + 1); strides[1] = 0
         prod = 1
         @inbounds for i in 1:N
-            strides[i+1] = prod
-            prod *= size(X,i)
+            strides[i + 1] = prod
+            prod *= size(X, i)
         end
         onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
         onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
     end
     stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
-    return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+    return cMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
 end
-function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
-    R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+    R = length(region); reg = NTuple{R, Int}(region)
     # For now, only support full transforms (all dimensions)
     if reg != ntuple(identity, N)
         error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -140,87 +144,87 @@ function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,C
     desc, q = _create_descriptor(size(X), T, true)
     onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_NOT_INPLACE)
     if N > 1
-        strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+        strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
         @inbounds for i in 1:N
-            strides[i+1]=prod; prod*=size(X,i)
+            strides[i + 1] = prod; prod *= size(X, i)
         end
         onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
         onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
     end
     stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
-    return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+    return cMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
 end
 
 # In-place (provide separate methods)
-function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
-    R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+    R = length(region); reg = NTuple{R, Int}(region)
     # For now, only support full transforms (all dimensions)
     if reg != ntuple(identity, N)
         error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
     end
-    desc,q = _create_descriptor(size(X),T,true)
+    desc, q = _create_descriptor(size(X), T, true)
     onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_INPLACE)
     if N > 1
-        strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+        strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
         @inbounds for i in 1:N
-            strides[i+1]=prod; prod*=size(X,i)
+            strides[i + 1] = prod; prod *= size(X, i)
         end
         onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
         onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
     end
     stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
-    cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+    return cMKLFFTPlan{T, MKLFFT_FORWARD, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
 end
-function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
-    R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+    R = length(region); reg = NTuple{R, Int}(region)
     # For now, only support full transforms (all dimensions)
     if reg != ntuple(identity, N)
         error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
     end
-    desc,q = _create_descriptor(size(X),T,true)
+    desc, q = _create_descriptor(size(X), T, true)
     onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_INPLACE)
     if N > 1
-        strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+        strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
         @inbounds for i in 1:N
-            strides[i+1]=prod; prod*=size(X,i)
+            strides[i + 1] = prod; prod *= size(X, i)
         end
         onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
         onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
     end
     stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
-    cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+    return cMKLFFTPlan{T, MKLFFT_INVERSE, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
 end
 
 # Real input methods - convert to complex like FFTW does
-function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_fft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
     CT = Complex{T}
     # Create a complex plan by converting the real array to complex
     X_complex = oneAPI.oneArray{CT}(undef, size(X))
-    plan_fft(X_complex, region)
+    return plan_fft(X_complex, region)
 end
 
-function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_bfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
     CT = Complex{T}
     # Create a complex plan by converting the real array to complex
     X_complex = oneAPI.oneArray{CT}(undef, size(X))
-    plan_bfft(X_complex, region)
+    return plan_bfft(X_complex, region)
 end
 
-function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_fft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
     error("In-place FFT not supported for real input arrays. Use plan_fft instead.")
 end
 
-function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_bfft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
     error("In-place FFT not supported for real input arrays. Use plan_bfft instead.")
 end
 
 # Real forward (out-of-place) - supports multi-dimensional transforms
-function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_rfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
     # Convert region to tuple if it's a range
     if isa(region, AbstractUnitRange)
         region = tuple(region...)
     end
-    R = length(region); reg = NTuple{R,Int}(region)
+    R = length(region); reg = NTuple{R, Int}(region)
 
     # For single dimension transforms, use the optimized oneMKL real FFT
     if R == 1 && reg[1] == 1
@@ -234,12 +238,12 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
 end
 
 # Single-dimension real FFT using oneMKL (optimized path)
-function _plan_rfft_1d(X::oneAPI.oneArray{T,N}, reg::NTuple{1,Int}) where {T<:Union{Float32,Float64},N}
+function _plan_rfft_1d(X::oneAPI.oneArray{T, N}, reg::NTuple{1, Int}) where {T <: Union{Float32, Float64}, N}
     # Create 1D descriptor for the transform dimension
-    desc,q = _create_descriptor((size(X, reg[1]),), T, false)
+    desc, q = _create_descriptor((size(X, reg[1]),), T, false)
     xdims = size(X)
     # output along first dim becomes N/2+1
-    ydims = Base.setindex(xdims, div(xdims[1],2)+1, 1)
+    ydims = Base.setindex(xdims, div(xdims[1], 2) + 1, 1)
     buffer = oneAPI.oneArray{Complex{T}}(undef, ydims)
     onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_NOT_INPLACE)
 
@@ -255,18 +259,18 @@ function _plan_rfft_1d(X::oneAPI.oneArray{T,N}, reg::NTuple{1,Int}) where {T<:Un
 
     stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
     R = length(reg)
-    rMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:rfft,reg,buffer,nothing)
+    return rMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :rfft, reg, buffer, nothing)
 end
 
 # Multi-dimensional real FFT using complex FFT approach
-struct ComplexBasedRealFFTPlan{T,N,R} <: MKLFFTPlan{T,MKLFFT_FORWARD,false}
-    complex_plan::cMKLFFTPlan{Complex{T},MKLFFT_FORWARD,false,N,R,Nothing}
-    sz::NTuple{N,Int}
-    osz::NTuple{N,Int}
-    region::NTuple{R,Int}
+struct ComplexBasedRealFFTPlan{T, N, R} <: MKLFFTPlan{T, MKLFFT_FORWARD, false}
+    complex_plan::cMKLFFTPlan{Complex{T}, MKLFFT_FORWARD, false, N, R, Nothing}
+    sz::NTuple{N, Int}
+    osz::NTuple{N, Int}
+    region::NTuple{R, Int}
 end
 
-function _plan_rfft_nd(X::oneAPI.oneArray{T,N}, reg::NTuple{R,Int}) where {T<:Union{Float32,Float64},N,R}
+function _plan_rfft_nd(X::oneAPI.oneArray{T, N}, reg::NTuple{R, Int}) where {T <: Union{Float32, Float64}, N, R}
     # Create complex version for planning
     X_complex = oneAPI.oneArray{Complex{T}}(undef, size(X))
     complex_plan = plan_fft(X_complex, reg)
@@ -281,18 +285,22 @@ function _plan_rfft_nd(X::oneAPI.oneArray{T,N}, reg::NTuple{R,Int}) where {T<:Un
         end
     end
 
-    ComplexBasedRealFFTPlan{T,N,R}(complex_plan, xdims, ydims, reg)
+    return ComplexBasedRealFFTPlan{T, N, R}(complex_plan, xdims, ydims, reg)
 end
 
 # Show method for complex-based plan
 function Base.show(io::IO, p::ComplexBasedRealFFTPlan{T}) where {T}
     print(io, "oneMKL FFT forward plan for ")
-    if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
-    print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
+    if isempty(p.sz)
+        print(io, "0-dimensional")
+    else
+        print(io, join(p.sz, "×"))
+    end
+    return print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
 end
 
 # Execution for complex-based real FFT plan
-function Base.:*(p::ComplexBasedRealFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) where {T,N,R}
+function Base.:*(p::ComplexBasedRealFFTPlan{T, N, R}, X::oneAPI.oneArray{T}) where {T, N, R}
     # Convert to complex
     X_complex = Complex{T}.(X)
 
@@ -316,14 +324,13 @@ function Base.:*(p::ComplexBasedRealFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) where
 end
 
 
-
 # Real inverse (complex->real) requires complex input shape - supports multi-dimensional transforms
-function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union{ComplexF32,ComplexF64},N}
+function plan_brfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T <: Union{ComplexF32, ComplexF64}, N}
     # Convert region to tuple if it's a range
     if isa(region, AbstractUnitRange)
         region = tuple(region...)
     end
-    R = length(region); reg = NTuple{R,Int}(region)
+    R = length(region); reg = NTuple{R, Int}(region)
 
     # For single dimension transforms along first dim, use optimized oneMKL path
     if R == 1 && reg[1] == 1
@@ -335,13 +342,13 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
 end
 
 # Single-dimension real inverse FFT using oneMKL (optimized path)
-function _plan_brfft_1d(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{1,Int}) where {T<:Union{ComplexF32,ComplexF64},N}
+function _plan_brfft_1d(X::oneAPI.oneArray{T, N}, d::Integer, reg::NTuple{1, Int}) where {T <: Union{ComplexF32, ComplexF64}, N}
     # Extract underlying real type R from Complex{R}
     @assert T <: Complex
     RT = T.parameters[1]
 
     # Create 1D descriptor for the transform dimension
-    desc,q = _create_descriptor((d,), RT, false)
+    desc, q = _create_descriptor((d,), RT, false)
     xdims = size(X)
     ydims = Base.setindex(xdims, d, 1)
     buffer = oneAPI.oneArray{T}(undef, xdims) # copy for safety
@@ -355,19 +362,19 @@ function _plan_brfft_1d(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{1,Int})
 
     stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
     R = length(reg)
-    rMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:brfft,reg,buffer,nothing)
+    return rMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :brfft, reg, buffer, nothing)
 end
 
 # Multi-dimensional real inverse FFT using complex FFT approach
-struct ComplexBasedRealIFFTPlan{T,N,R} <: MKLFFTPlan{T,MKLFFT_INVERSE,false}
-    complex_plan::cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}
-    sz::NTuple{N,Int}
-    osz::NTuple{N,Int}
-    region::NTuple{R,Int}
+struct ComplexBasedRealIFFTPlan{T, N, R} <: MKLFFTPlan{T, MKLFFT_INVERSE, false}
+    complex_plan::cMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, Nothing}
+    sz::NTuple{N, Int}
+    osz::NTuple{N, Int}
+    region::NTuple{R, Int}
     d::Int  # Original size of the reduced dimension
 end
 
-function _plan_brfft_nd(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{R,Int}) where {T<:Union{ComplexF32,ComplexF64},N,R}
+function _plan_brfft_nd(X::oneAPI.oneArray{T, N}, d::Integer, reg::NTuple{R, Int}) where {T <: Union{ComplexF32, ComplexF64}, N, R}
     # Calculate the full complex array size (before real FFT reduction)
     xdims = size(X)
     full_complex_dims = ntuple(N) do i
@@ -382,18 +389,22 @@ function _plan_brfft_nd(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{R,Int})
     X_complex_full = oneAPI.oneArray{T}(undef, full_complex_dims)
     complex_plan = plan_bfft(X_complex_full, reg)
 
-    ComplexBasedRealIFFTPlan{T,N,R}(complex_plan, xdims, full_complex_dims, reg, d)
+    return ComplexBasedRealIFFTPlan{T, N, R}(complex_plan, xdims, full_complex_dims, reg, d)
 end
 
 # Show method for complex-based inverse plan
 function Base.show(io::IO, p::ComplexBasedRealIFFTPlan{T}) where {T}
     print(io, "oneMKL FFT inverse plan for ")
-    if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
-    print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
+    if isempty(p.sz)
+        print(io, "0-dimensional")
+    else
+        print(io, join(p.sz, "×"))
+    end
+    return print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
 end
 
 # Execution for complex-based real inverse FFT plan
-function Base.:*(p::ComplexBasedRealIFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) where {T,N,R}
+function Base.:*(p::ComplexBasedRealIFFTPlan{T, N, R}, X::oneAPI.oneArray{T}) where {T, N, R}
     # Reconstruct full complex array by exploiting conjugate symmetry
     # This is a simplified approach - for full accuracy, we'd need to properly
     # reconstruct the conjugate symmetric part
@@ -435,7 +446,7 @@ function Base.:*(p::ComplexBasedRealIFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) wher
 end
 
 # Inverse plan for complex-based real FFT plans
-function plan_inv(p::ComplexBasedRealFFTPlan{T,N,R}) where {T,N,R}
+function plan_inv(p::ComplexBasedRealFFTPlan{T, N, R}) where {T, N, R}
     # For real FFT inverse, we need plan_brfft functionality
     # The first dimension in the region should be the one that was reduced
     first_dim = minimum(p.region)
@@ -443,18 +454,17 @@ function plan_inv(p::ComplexBasedRealFFTPlan{T,N,R}) where {T,N,R}
 
     # Create inverse plan using our new multi-dimensional brfft
     brfft_plan = _plan_brfft_nd(oneAPI.oneArray{Complex{T}}(undef, p.osz), d, p.region)
-    ScaledPlan(brfft_plan, 1/normalization_factor(p.sz, p.region))
+    return ScaledPlan(brfft_plan, 1 / normalization_factor(p.sz, p.region))
 end
 
 # Inverse plan for complex-based real inverse FFT plans
-function plan_inv(p::ComplexBasedRealIFFTPlan{T,N,R}) where {T,N,R}
+function plan_inv(p::ComplexBasedRealIFFTPlan{T, N, R}) where {T, N, R}
     # Create forward plan
     forward_plan = _plan_rfft_nd(oneAPI.oneArray{real(T)}(undef, p.osz), p.region)
-    ScaledPlan(forward_plan, 1/normalization_factor(p.osz, p.region))
+    return ScaledPlan(forward_plan, 1 / normalization_factor(p.osz, p.region))
 end
 
 
-
 # Convenience no-region methods use all dimensions in order
 plan_fft(X::oneAPI.oneArray) = plan_fft(X, ntuple(identity, ndims(X)))
 plan_bfft(X::oneAPI.oneArray) = plan_bfft(X, ntuple(identity, ndims(X)))
@@ -467,111 +477,119 @@ plan_brfft(X::oneAPI.oneArray, d::Integer) = plan_brfft(X, d, ntuple(identity, n
 const plan_ifft = plan_bfft
 const plan_ifft! = plan_bfft!
 # plan_irfft should be normalized, unlike plan_brfft
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T,N} = begin
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T, N} = begin
     p = plan_brfft(X, d, region)
-    ScaledPlan(p, 1/normalization_factor(p.sz, p.region))
+    ScaledPlan(p, 1 / normalization_factor(p.sz, p.region))
 end
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer) where {T,N} = plan_irfft(X, d, (1,))
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer) where {T, N} = plan_irfft(X, d, (1,))
 
 # Inversion
 Base.inv(p::MKLFFTPlan) = plan_inv(p)
 
 # High-level wrappers operating like CPU FFTW versions.
-function fft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
-    (plan_fft(X) * X)
+function fft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
+    return (plan_fft(X) * X)
 end
-function ifft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
     p = plan_bfft(X)
     # Apply normalization for ifft (unlike bfft which is unnormalized)
     scaling = one(T) / normalization_factor(size(X), ntuple(identity, ndims(X)))
-    scaling * (p * X)
+    return scaling * (p * X)
 end
-function fft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function fft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
     (plan_fft!(X) * X; X)
 end
-function ifft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
     p = plan_bfft!(X)
     # Apply normalization for ifft! (unlike bfft! which is unnormalized)
     scaling = one(T) / normalization_factor(size(X), ntuple(identity, ndims(X)))
     p * X
     X .*= scaling
-    X
+    return X
 end
-function rfft(X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
-    (plan_rfft(X) * X)
+function rfft(X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+    return (plan_rfft(X) * X)
 end
-function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T<:Union{ComplexF32,ComplexF64}}
+function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T <: Union{ComplexF32, ComplexF64}}
     # Use the normalized plan_irfft instead of unnormalized plan_brfft
-    (plan_irfft(X, d) * X)
+    return (plan_irfft(X, d) * X)
 end
 
 # Execution helpers
-_rawptr(a::oneAPI.oneArray{T}) where T = reinterpret(Ptr{Cvoid}, pointer(a))
+_rawptr(a::oneAPI.oneArray{T}) where {T} = reinterpret(Ptr{Cvoid}, pointer(a))
 
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_FORWARD,true}, X::oneAPI.oneArray{T}) where T
-    st = onemklDftComputeForward(p.handle, _rawptr(X)); st==0 || error("forward FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_FORWARD, true}, X::oneAPI.oneArray{T}) where {T}
+    st = onemklDftComputeForward(p.handle, _rawptr(X)); st == 0 || error("forward FFT failed ($st)")
+    return X
 end
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_INVERSE,true}, X::oneAPI.oneArray{T}) where T
-    st = onemklDftComputeBackward(p.handle, _rawptr(X)); st==0 || error("inverse FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_INVERSE, true}, X::oneAPI.oneArray{T}) where {T}
+    st = onemklDftComputeBackward(p.handle, _rawptr(X)); st == 0 || error("inverse FFT failed ($st)")
+    return X
 end
-function _exec!(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T,K}
-    st = (K==MKLFFT_FORWARD ? onemklDftComputeForwardOutOfPlace : onemklDftComputeBackwardOutOfPlace)(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("FFT failed ($st)"); Y
+function _exec!(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T, K}
+    st = (K == MKLFFT_FORWARD ? onemklDftComputeForwardOutOfPlace : onemklDftComputeBackwardOutOfPlace)(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("FFT failed ($st)")
+    return Y
 end
 
 # Real forward
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where T
-    st = onemklDftComputeForwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("rfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where {T}
+    st = onemklDftComputeForwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("rfft failed ($st)")
+    return Y
 end
 # Real inverse (complex -> real)
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R,T<:Complex{R}}
-    st = onemklDftComputeBackwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("brfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R, T <: Complex{R}}
+    st = onemklDftComputeBackwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("brfft failed ($st)")
+    return Y
 end
 
 # Public API similar to AMDGPU
-function Base.:*(p::cMKLFFTPlan{T,K,true}, X::oneAPI.oneArray{T}) where {T,K}
-    _exec!(p,X)
+function Base.:*(p::cMKLFFTPlan{T, K, true}, X::oneAPI.oneArray{T}) where {T, K}
+    return _exec!(p, X)
 end
-function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
-    Y = oneAPI.oneArray{T}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+    Y = oneAPI.oneArray{T}(undef, p.osz)
+    return _exec!(p, X, Y)
 end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
-    _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+    return _exec!(p, X, Y)
 end
 
 # Real forward
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
-    Y = oneAPI.oneArray{Complex{T}}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+    Y = oneAPI.oneArray{Complex{T}}(undef, p.osz)
+    return _exec!(p, X, Y)
 end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
-    _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+    return _exec!(p, X, Y)
 end
 # Real inverse
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
-    Y = oneAPI.oneArray{R}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+    Y = oneAPI.oneArray{R}(undef, p.osz)
+    return _exec!(p, X, Y)
 end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
-    _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+    return _exec!(p, X, Y)
 end
 
 # Support for applying complex plans to real arrays (convert real to complex first)
-function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{R}) where {T,K,R<:Union{Float32,Float64}}
+function Base.:*(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{R}) where {T, K, R <: Union{Float32, Float64}}
     # Only allow if T is the complex version of R
     if T != Complex{R}
         error("Type mismatch: plan expects $(T) but got $(R)")
     end
     # Convert real input to complex
     X_complex = complex.(X)
-    p * X_complex
+    return p * X_complex
 end
 
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{R}) where {T,K,R<:Union{Float32,Float64}}
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{R}) where {T, K, R <: Union{Float32, Float64}}
     # Only allow if T is the complex version of R
     if T != Complex{R}
         error("Type mismatch: plan expects $(T) but got $(R)")
     end
     # Convert real input to complex
     X_complex = complex.(X)
-    _exec!(p, X_complex, Y)
+    return _exec!(p, X_complex, Y)
 end
 
 end # module FFT
diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl
index 06d8bee..0ea694b 100644
--- a/lib/support/liboneapi_support.jl
+++ b/lib/support/liboneapi_support.jl
@@ -7111,122 +7111,160 @@ mutable struct onemklDftDescriptor_st end
 const onemklDftDescriptor_t = Ptr{onemklDftDescriptor_st}
 
 function onemklDftCreate1D(desc, precision, domain, length)
-    @ccall liboneapi_support.onemklDftCreate1D(desc::Ptr{onemklDftDescriptor_t},
-                                               precision::onemklDftPrecision,
-                                               domain::onemklDftDomain, length::Int64)::Cint
+    return @ccall liboneapi_support.onemklDftCreate1D(
+        desc::Ptr{onemklDftDescriptor_t},
+        precision::onemklDftPrecision,
+        domain::onemklDftDomain, length::Int64
+    )::Cint
 end
 
 function onemklDftCreateND(desc, precision, domain, dim, lengths)
-    @ccall liboneapi_support.onemklDftCreateND(desc::Ptr{onemklDftDescriptor_t},
-                                               precision::onemklDftPrecision,
-                                               domain::onemklDftDomain, dim::Int64,
-                                               lengths::Ptr{Int64})::Cint
+    return @ccall liboneapi_support.onemklDftCreateND(
+        desc::Ptr{onemklDftDescriptor_t},
+        precision::onemklDftPrecision,
+        domain::onemklDftDomain, dim::Int64,
+        lengths::Ptr{Int64}
+    )::Cint
 end
 
 function onemklDftDestroy(desc)
-    @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
+    return @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
 end
 
 function onemklDftCommit(desc, queue)
-    @ccall liboneapi_support.onemklDftCommit(desc::onemklDftDescriptor_t,
-                                             queue::syclQueue_t)::Cint
+    return @ccall liboneapi_support.onemklDftCommit(
+        desc::onemklDftDescriptor_t,
+        queue::syclQueue_t
+    )::Cint
 end
 
 function onemklDftSetValueInt64(desc, param, value)
-    @ccall liboneapi_support.onemklDftSetValueInt64(desc::onemklDftDescriptor_t,
-                                                    param::onemklDftConfigParam,
-                                                    value::Int64)::Cint
+    return @ccall liboneapi_support.onemklDftSetValueInt64(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Int64
+    )::Cint
 end
 
 function onemklDftSetValueDouble(desc, param, value)
-    @ccall liboneapi_support.onemklDftSetValueDouble(desc::onemklDftDescriptor_t,
-                                                     param::onemklDftConfigParam,
-                                                     value::Cdouble)::Cint
+    return @ccall liboneapi_support.onemklDftSetValueDouble(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Cdouble
+    )::Cint
 end
 
 function onemklDftSetValueInt64Array(desc, param, values, n)
-    @ccall liboneapi_support.onemklDftSetValueInt64Array(desc::onemklDftDescriptor_t,
-                                                         param::onemklDftConfigParam,
-                                                         values::Ptr{Int64}, n::Int64)::Cint
+    return @ccall liboneapi_support.onemklDftSetValueInt64Array(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        values::Ptr{Int64}, n::Int64
+    )::Cint
 end
 
 function onemklDftSetValueConfigValue(desc, param, value)
-    @ccall liboneapi_support.onemklDftSetValueConfigValue(desc::onemklDftDescriptor_t,
-                                                          param::onemklDftConfigParam,
-                                                          value::onemklDftConfigValue)::Cint
+    return @ccall liboneapi_support.onemklDftSetValueConfigValue(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::onemklDftConfigValue
+    )::Cint
 end
 
 function onemklDftGetValueInt64(desc, param, value)
-    @ccall liboneapi_support.onemklDftGetValueInt64(desc::onemklDftDescriptor_t,
-                                                    param::onemklDftConfigParam,
-                                                    value::Ptr{Int64})::Cint
+    return @ccall liboneapi_support.onemklDftGetValueInt64(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Ptr{Int64}
+    )::Cint
 end
 
 function onemklDftGetValueDouble(desc, param, value)
-    @ccall liboneapi_support.onemklDftGetValueDouble(desc::onemklDftDescriptor_t,
-                                                     param::onemklDftConfigParam,
-                                                     value::Ptr{Cdouble})::Cint
+    return @ccall liboneapi_support.onemklDftGetValueDouble(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Ptr{Cdouble}
+    )::Cint
 end
 
 function onemklDftGetValueInt64Array(desc, param, values, n)
-    @ccall liboneapi_support.onemklDftGetValueInt64Array(desc::onemklDftDescriptor_t,
-                                                         param::onemklDftConfigParam,
-                                                         values::Ptr{Int64},
-                                                         n::Ptr{Int64})::Cint
+    return @ccall liboneapi_support.onemklDftGetValueInt64Array(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        values::Ptr{Int64},
+        n::Ptr{Int64}
+    )::Cint
 end
 
 function onemklDftGetValueConfigValue(desc, param, value)
-    @ccall liboneapi_support.onemklDftGetValueConfigValue(desc::onemklDftDescriptor_t,
-                                                          param::onemklDftConfigParam,
-                                                          value::Ptr{onemklDftConfigValue})::Cint
+    return @ccall liboneapi_support.onemklDftGetValueConfigValue(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Ptr{onemklDftConfigValue}
+    )::Cint
 end
 
 function onemklDftComputeForward(desc, inout)
-    @ccall liboneapi_support.onemklDftComputeForward(desc::onemklDftDescriptor_t,
-                                                     inout::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeForward(
+        desc::onemklDftDescriptor_t,
+        inout::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeForwardOutOfPlace(desc, in, out)
-    @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(desc::onemklDftDescriptor_t,
-                                                               in::Ptr{Cvoid},
-                                                               out::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(
+        desc::onemklDftDescriptor_t,
+        in::Ptr{Cvoid},
+        out::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeBackward(desc, inout)
-    @ccall liboneapi_support.onemklDftComputeBackward(desc::onemklDftDescriptor_t,
-                                                      inout::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeBackward(
+        desc::onemklDftDescriptor_t,
+        inout::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeBackwardOutOfPlace(desc, in, out)
-    @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(desc::onemklDftDescriptor_t,
-                                                                in::Ptr{Cvoid},
-                                                                out::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(
+        desc::onemklDftDescriptor_t,
+        in::Ptr{Cvoid},
+        out::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeForwardBuffer(desc, inout)
-    @ccall liboneapi_support.onemklDftComputeForwardBuffer(desc::onemklDftDescriptor_t,
-                                                           inout::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeForwardBuffer(
+        desc::onemklDftDescriptor_t,
+        inout::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeForwardOutOfPlaceBuffer(desc, in, out)
-    @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
-                                                                     in::Ptr{Cvoid},
-                                                                     out::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(
+        desc::onemklDftDescriptor_t,
+        in::Ptr{Cvoid},
+        out::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeBackwardBuffer(desc, inout)
-    @ccall liboneapi_support.onemklDftComputeBackwardBuffer(desc::onemklDftDescriptor_t,
-                                                            inout::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeBackwardBuffer(
+        desc::onemklDftDescriptor_t,
+        inout::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeBackwardOutOfPlaceBuffer(desc, in, out)
-    @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
-                                                                      in::Ptr{Cvoid},
-                                                                      out::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(
+        desc::onemklDftDescriptor_t,
+        in::Ptr{Cvoid},
+        out::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftQueryParamIndices(out, n)
-    @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
+    return @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
 end
 
 const ONEMKL_DFT_STATUS_SUCCESS = 0
diff --git a/res/wrap.jl b/res/wrap.jl
index 1d48315..2e9b29f 100644
--- a/res/wrap.jl
+++ b/res/wrap.jl
@@ -112,14 +112,14 @@ using oneAPI_Level_Zero_Headers_jll
 
 function main()
     wrap("ze", oneAPI_Level_Zero_Headers_jll.ze_api)
-    wrap(
-            "support",
-            joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
-            joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
-            joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
-            dependents=false,
-            include_dirs=[dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
-        )
+    return wrap(
+        "support",
+        joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
+        joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
+        joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
+        dependents = false,
+        include_dirs = [dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
+    )
 end
 
 isinteractive() || main()
diff --git a/test/fft.jl b/test/fft.jl
index 1b148df..ef81c21 100644
--- a/test/fft.jl
+++ b/test/fft.jl
@@ -7,39 +7,39 @@ using Random
 Random.seed!(1234)
 
 # Helper to move data to GPU
-gpu(A::AbstractArray{T}) where T = oneAPI.oneArray{T}(A)
+gpu(A::AbstractArray{T}) where {T} = oneAPI.oneArray{T}(A)
 struct _Plan end
 struct _FFT end
 
-const MYRTOL = 1e-5
-const MYATOL = 1e-8
+const MYRTOL = 1.0e-5
+const MYATOL = 1.0e-8
 
-function cmp(a,b; rtol=MYRTOL, atol=MYATOL)
-    @test isapprox(Array(a), Array(b); rtol=rtol, atol=atol)
+function cmp(a, b; rtol = MYRTOL, atol = MYATOL)
+    return @test isapprox(Array(a), Array(b); rtol = rtol, atol = atol)
 end
 
-function test_plan(::_Plan, plan, X::AbstractArray{T,N}) where {T,N}
+function test_plan(::_Plan, plan, X::AbstractArray{T, N}) where {T, N}
     p = plan(X)
     Y = p * X
     return Y
 end
 
-function test_plan(::_FFT, f, X::AbstractArray{T,N}) where {T,N}
+function test_plan(::_FFT, f, X::AbstractArray{T, N}) where {T, N}
     Y = if f === AbstractFFTs.irfft || f === AbstractFFTs.brfft
-        f(X, size(X, ndims(X))*2 - 2)
+        f(X, size(X, ndims(X)) * 2 - 2)
     else
         f(X)
     end
     return Y
 end
 
-function test_plan(t, plan::Function, dim::Tuple, T::Type, iplan=nothing)
+function test_plan(t, plan::Function, dim::Tuple, T::Type, iplan = nothing)
     X = rand(T, dim)
     dX = gpu(X)
     Y = test_plan(t, plan, X)
     dY = test_plan(t, plan, dX)
     cmp(dY, Y)
-    if iplan !== nothing
+    return if iplan !== nothing
         iX = test_plan(t, iplan, Y)
         idX = test_plan(t, iplan, dY)
         cmp(idX, iX)
@@ -47,36 +47,36 @@ function test_plan(t, plan::Function, dim::Tuple, T::Type, iplan=nothing)
 end
 
 @testset "FFT" begin
-@testset "$(length(dim))D" for dim in [(8,), (8,32), (8,32,64)]
-    test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_ifft)
-    test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_bfft)
-    test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_ifft)
-    test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_bfft)
-    test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float32)
-    test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF32, AbstractFFTs.plan_bfft!)
-    # Not part of FFTW
-    # test_plan(AbstractFFTs.plan_rfft!, Float32)
-    test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.ifft)
-    test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.bfft)
-    if length(dim) == 1  # irfft/brfft only for 1D
-        test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.irfft)
-        test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.brfft)
-    end
-    if (ComplexF64 in eltypes) && (Float64 in eltypes)
-        test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_ifft)
-        test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_bfft)
-        test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_ifft)
-        test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_bfft)
-        test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float64)
-        test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF64, AbstractFFTs.plan_bfft!)
+    @testset "$(length(dim))D" for dim in [(8,), (8, 32), (8, 32, 64)]
+        test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_ifft)
+        test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_bfft)
+        test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_ifft)
+        test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_bfft)
+        test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float32)
+        test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF32, AbstractFFTs.plan_bfft!)
         # Not part of FFTW
-        # test_plan(AbstractFFTs.plan_rfft!, Float64)
-        test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.ifft)
-        test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.bfft)
+        # test_plan(AbstractFFTs.plan_rfft!, Float32)
+        test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.ifft)
+        test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.bfft)
         if length(dim) == 1  # irfft/brfft only for 1D
-            test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.irfft)
-            test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.brfft)
+            test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.irfft)
+            test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.brfft)
+        end
+        if (ComplexF64 in eltypes) && (Float64 in eltypes)
+            test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_ifft)
+            test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_bfft)
+            test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_ifft)
+            test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_bfft)
+            test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float64)
+            test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF64, AbstractFFTs.plan_bfft!)
+            # Not part of FFTW
+            # test_plan(AbstractFFTs.plan_rfft!, Float64)
+            test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.ifft)
+            test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.bfft)
+            if length(dim) == 1  # irfft/brfft only for 1D
+                test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.irfft)
+                test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.brfft)
+            end
         end
     end
 end
-end

}
*out = desc;
return 0;
} catch (...) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a catch-all.

@codecov
Copy link

codecov bot commented Sep 4, 2025

Codecov Report

❌ Patch coverage is 60.79137% with 109 lines in your changes missing coverage. Please review.
✅ Project coverage is 79.70%. Comparing base (3e30673) to head (e19e950).
⚠️ Report is 1 commit behind head on master.

Files with missing lines Patch % Lines
lib/mkl/fft.jl 60.79% 109 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #515      +/-   ##
==========================================
- Coverage   81.73%   79.70%   -2.04%     
==========================================
  Files          44       45       +1     
  Lines        2540     2818     +278     
==========================================
+ Hits         2076     2246     +170     
- Misses        464      572     +108     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

lib/mkl/fft.jl Outdated
Comment on lines 58 to 69
# Raw `ccall` bindings for the oneMKL DFT entry points exported by `lib`.
# Every binding declares a `Cint` return type and passes the descriptor handle
# as an untyped `Ptr{Cvoid}`.

# Descriptor creation: `desc_ref` is passed as `Ref{Ptr{Cvoid}}`
# (presumably it receives the newly created descriptor handle — confirm against the C header).
ccall_create1d(desc_ref, prec::Int32, dom::Int32, length::Int64) = ccall((:onemklDftCreate1D, lib), Cint, (Ref{Ptr{Cvoid}}, Cint, Cint, Int64), desc_ref, prec, dom, length)
ccall_creatend(desc_ref, prec::Int32, dom::Int32, dim::Int64, lengths::Ptr{Int64}) = ccall((:onemklDftCreateND, lib), Cint, (Ref{Ptr{Cvoid}}, Cint, Cint, Int64, Ptr{Int64}), desc_ref, prec, dom, dim, lengths)
# Descriptor teardown and commit; commit additionally takes the SYCL queue handle.
ccall_destroy(desc) = ccall((:onemklDftDestroy, lib), Cint, (Ptr{Cvoid},), desc)
ccall_commit(desc, q) = ccall((:onemklDftCommit, lib), Cint, (Ptr{Cvoid}, syclQueue_t), desc, q)
# Forward / backward compute calls: the single-pointer forms take one data pointer,
# the `OutOfPlace` forms take separate input (`pin`) and output (`pout`) pointers.
ccall_fwd(desc, ptr) = ccall((:onemklDftComputeForward, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), desc, ptr)
ccall_fwd_oop(desc, pin, pout) = ccall((:onemklDftComputeForwardOutOfPlace, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), desc, pin, pout)
ccall_bwd(desc, ptr) = ccall((:onemklDftComputeBackward, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), desc, ptr)
ccall_bwd_oop(desc, pin, pout) = ccall((:onemklDftComputeBackwardOutOfPlace, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), desc, pin, pout)
# Scalar configuration setters: `param` is passed as `Cint`, the value as Float64 or Int64.
ccall_set_double(desc, param::Int32, value::Float64) = ccall((:onemklDftSetValueDouble, lib), Cint, (Ptr{Cvoid}, Cint, Float64), desc, param, value)
ccall_set_int(desc, param::Int32, value::Int64) = ccall((:onemklDftSetValueInt64, lib), Cint, (Ptr{Cvoid}, Cint, Int64), desc, param, value)
# Set an Int64-array-valued configuration parameter on a DFT descriptor.
# `values` is handed to `ccall` directly (not as `pointer(values)`): ccall's argument
# conversion for the `Ptr{Int64}` slot keeps the array rooted for the duration of the
# foreign call, whereas a raw pointer taken beforehand does not protect it from the GC.
# The element count is passed as the trailing Int64 argument; returns the C call's `Cint` result.
ccall_set_int64_array(desc, param::Int32, values::Vector{Int64}) = ccall((:onemklDftSetValueInt64Array, lib), Cint, (Ptr{Cvoid}, Cint, Ptr{Int64}, Int64), desc, param, values, length(values))
# Set a configuration parameter whose value is itself a config-value code;
# both `param` and `value` are Int32 on the Julia side and passed as `Cint`.
ccall_set_cfg(desc, param::Int32, value::Int32) = ccall((:onemklDftSetValueConfigValue, lib), Cint, (Ptr{Cvoid}, Cint, Cint), desc, param, value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michel2323 Please use the wrappers generated by Clang.jl.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is generated with Clang. I'm not sure I understand.

joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see now.

lib/mkl/fft.jl Outdated
R = length(region); reg = NTuple{R,Int}(region)
# Only support single dimension transforms for now
if R != 1
error("Multi-dimensional real FFT not yet supported")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michel2323 Do we know if this is a feature not yet implemented by Intel, or something to improve on our side?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get wrong values. Maybe I did something wrong, or the Intel library is wrong.

Copy link
Member

@amontoison amontoison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michel2323 Please address the two comments and it should be good for me.

Copy link
Member

@amontoison amontoison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM (if all tests passed)!

@michel2323 michel2323 merged commit 3d3278d into master Sep 16, 2025
2 checks passed
@michel2323 michel2323 deleted the ms/fft branch September 16, 2025 16:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants