diff --git a/src/complex_matmul.jl b/src/complex_matmul.jl index dc92bb1..1148dfc 100644 --- a/src/complex_matmul.jl +++ b/src/complex_matmul.jl @@ -25,6 +25,133 @@ for AT in [:AbstractVector, :AbstractMatrix] # to avoid ambiguity error _C end + function _matmul_v2!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::$AT{Complex{V}}, + α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V} + # C, A, B = map(real_rep, (_C, _A, _B)) + C = reinterpret(T, _C) + A = reinterpret(T, _A) + B = real_rep(_B) + + η_bool = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1)) + θ_bool = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1)) + (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -)) + # ηθ = η*θ + + signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...) + if !η_bool & !θ_bool + cmatmul_ab(C, A, B) + + _C + end + + function cmatmul_ab!(C, A, B) + @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) + Cmn = zero(T) + for k ∈ indices((A, B), (2, 2)) + Amk = A[m,k] + Aperm = vpermilps177(Amk) + + # A B + Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) + # A^* B + # Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn)) + # A B^* + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn)) + # A^* B^* + # Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn)) + + # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] + # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] + end + C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) + C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) + C[m, n] = Cmn + end + end + + function cmatmul_astarb() + + signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...) + + @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) + Cmn = zero(T) + for k ∈ indices((A, B), (2, 2)) + Amk = A[m,k] + Aperm = vpermilps177(Amk) + + # A B + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) + # A^* B + Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn)) + # A B^* + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn)) + # A^* B^* + # Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn)) + + # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] + # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] + end + C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) + C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) + C[m, n] = Cmn + end + end + + function cmatmul_abstar() + @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) + Cmn = zero(T) + for k ∈ indices((A, B), (2, 2)) + Amk = A[m,k] + Aperm = vpermilps177(Amk) + + # TODO: I don't yet know how to pick the correct branch + # based on η and θ. + # A B + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) + # A^* B + # Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn)) + # A B^* + Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn)) + # A^* B^* + # Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn)) + + # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] + # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] + end + C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) + C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) + C[m, n] = Cmn + end + end + + function cmatmul_astarbstar() + + signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...) + + @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) + Cmn = zero(T) + for k ∈ indices((A, B), (2, 2)) + Amk = A[m,k] + Aperm = vpermilps177(Amk) + + # A B + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) + # A^* B + # Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn)) + # A B^* + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn)) + # A^* B^* + Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn)) + + # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] + # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] + end + C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) + C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) + C[m, n] = Cmn + end + end + @inline function _matmul!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}}, α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V} C, B = map(real_rep, (_C, _B))