From 0046fa00dec938ed6303f643153a82f9c1868635 Mon Sep 17 00:00:00 2001 From: Anubhab Haldar Date: Wed, 29 Jun 2022 11:04:48 -0400 Subject: [PATCH 1/3] Barebones behaviour parity --- src/complex_matmul.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/complex_matmul.jl b/src/complex_matmul.jl index dc92bb1..d7925d2 100644 --- a/src/complex_matmul.jl +++ b/src/complex_matmul.jl @@ -25,6 +25,34 @@ for AT in [:AbstractVector, :AbstractMatrix] # to avoid ambiguity error _C end + function _matmul_v2!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::$AT{Complex{V}}, + α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V} + # C, A, B = map(real_rep, (_C, _A, _B)) + C = reinterpret(T, _C) + A = reinterpret(T, _A) + B = reinterpret(reshape, T, _B) + + η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1)) + θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1)) + (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -)) + ηθ = η*θ + + @tturbo for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) + Cmn = zero(T) + for k ∈ indices((A, B), (2, 2)) + Amk = A[m,k] + Aperm = vpermilps177(Amk) + Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) + # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] + # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] + end + # C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) + # C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) + C[m, n] = Cmn + end + _C + end + @inline function _matmul!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}}, α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V} C, B = map(real_rep, (_C, _B)) From 2b134e56e1a773b9c5f1d975cf6e2d14f0935667 Mon Sep 17 00:00:00 2001 From: Anubhab Haldar Date: Wed, 29 Jun 2022 23:15:15 -0400 Subject: [PATCH 2/3] Kernels in place, can't figure out dynamic kernel selection, sometimes NaNs... --- src/complex_matmul.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/complex_matmul.jl b/src/complex_matmul.jl index d7925d2..8809135 100644 --- a/src/complex_matmul.jl +++ b/src/complex_matmul.jl @@ -30,28 +30,40 @@ for AT in [:AbstractVector, :AbstractMatrix] # to avoid ambiguity error # C, A, B = map(real_rep, (_C, _A, _B)) C = reinterpret(T, _C) A = reinterpret(T, _A) - B = reinterpret(reshape, T, _B) + B = real_rep(_B) η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1)) θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1)) (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -)) ηθ = η*θ - @tturbo for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) + signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...) + + @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) Cmn = zero(T) for k ∈ indices((A, B), (2, 2)) Amk = A[m,k] Aperm = vpermilps177(Amk) + + # TODO: I don't yet know how to pick the correct branch + # based on η and θ. + # A B Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) + # A^* B + # Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn)) + # A B^* + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn)) + # A^* B^* + # Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn)) + # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] end # C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) # C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) C[m, n] = Cmn - end - _C end + _C @inline function _matmul!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}}, α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V} From 54116ebdf7d8cc6d463716ca4c809b35f22ccffd Mon Sep 17 00:00:00 2001 From: Anubhab Haldar Date: Mon, 1 Aug 2022 00:23:00 -0400 Subject: [PATCH 3/3] Just straight up duplicate... --- src/complex_matmul.jl | 103 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 95 insertions(+), 8 deletions(-) diff --git a/src/complex_matmul.jl b/src/complex_matmul.jl index 8809135..1148dfc 100644 --- a/src/complex_matmul.jl +++ b/src/complex_matmul.jl @@ -32,21 +32,25 @@ for AT in [:AbstractVector, :AbstractMatrix] # to avoid ambiguity error A = reinterpret(T, _A) B = real_rep(_B) - η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1)) - θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1)) + η_bool = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1)) + θ_bool = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1)) (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -)) - ηθ = η*θ + # ηθ = η*θ signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...) + if !η_bool & !θ_bool + cmatmul_ab(C, A, B) + + _C + end + function cmatmul_ab!(C, A, B) @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) Cmn = zero(T) for k ∈ indices((A, B), (2, 2)) Amk = A[m,k] Aperm = vpermilps177(Amk) - # TODO: I don't yet know how to pick the correct branch - # based on η and θ. # A B Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) # A^* B @@ -59,11 +63,94 @@ for AT in [:AbstractVector, :AbstractMatrix] # to avoid ambiguity error # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] end - # C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) - # C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) + C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) + C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) C[m, n] = Cmn end - _C + end + + function cmatmul_astarb() + + signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...) + + @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) + Cmn = zero(T) + for k ∈ indices((A, B), (2, 2)) + Amk = A[m,k] + Aperm = vpermilps177(Amk) + + # A B + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) + # A^* B + Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn)) + # A B^* + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn)) + # A^* B^* + # Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn)) + + # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] + # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] + end + C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) + C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) + C[m, n] = Cmn + end + end + + function cmatmul_abstar() + @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) + Cmn = zero(T) + for k ∈ indices((A, B), (2, 2)) + Amk = A[m,k] + Aperm = vpermilps177(Amk) + + # TODO: I don't yet know how to pick the correct branch + # based on η and θ. + # A B + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) + # A^* B + # Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn)) + # A B^* + Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn)) + # A^* B^* + # Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn)) + + # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] + # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] + end + C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) + C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) + C[m, n] = Cmn + end + end + + function cmatmul_astarbstar() + + signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...) + + @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1) + Cmn = zero(T) + for k ∈ indices((A, B), (2, 2)) + Amk = A[m,k] + Aperm = vpermilps177(Amk) + + # A B + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn)) + # A^* B + # Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn)) + # A B^* + # Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn)) + # A^* B^* + Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn)) + + # Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n] + # Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n] + end + C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n]) + C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n]) + C[m, n] = Cmn + end + end @inline function _matmul!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}}, α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}