Skip to content

Commit 0cc5518

Browse files
authored
Scaling loop instead of broadcasting in strided matrix exp (#56463)
Firstly, this is easier to read. Secondly, this merges the two loops into one. Thirdly, this avoids the broadcasting latency. ```julia julia> using LinearAlgebra julia> A = rand(2,2); julia> @time LinearAlgebra.exp!(A); 0.952597 seconds (2.35 M allocations: 116.574 MiB, 2.67% gc time, 99.01% compilation time) # master 0.877404 seconds (2.17 M allocations: 106.293 MiB, 2.65% gc time, 99.99% compilation time) # this PR ``` The performance also improves as there are fewer allocations in the first branch (`opnorm(A, 1) <= 2.1`): ```julia julia> B = diagm(0=>im.*(float.(1:200))./200, 1=>(1:199)./400, -1=>(1:199)./400); julia> opnorm(B,1) 1.9875 julia> @Btime exp($B); 5.066 ms (30 allocations: 4.89 MiB) # nightly v"1.12.0-DEV.1581" 4.926 ms (27 allocations: 4.28 MiB) # this PR ```
1 parent cd748a5 commit 0cc5518

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -707,25 +707,32 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat
707707
# Compute U and V: Even/odd terms in Padé numerator & denom
708708
# Expansion of k=1 in for loop
709709
P = A2
710-
U = mul!(C[4]*P, true, C[2]*I, true, true) #U = C[2]*I + C[4]*P
711-
V = mul!(C[3]*P, true, C[1]*I, true, true) #V = C[1]*I + C[3]*P
710+
U = similar(P)
711+
V = similar(P)
712+
for ind in CartesianIndices(P)
713+
U[ind] = C[4]*P[ind] + C[2]*I[ind]
714+
V[ind] = C[3]*P[ind] + C[1]*I[ind]
715+
end
712716
for k in 2:(div(length(C), 2) - 1)
713717
P *= A2
714-
for ind in eachindex(P)
718+
for ind in eachindex(P, U, V)
715719
U[ind] += C[2k + 2] * P[ind]
716720
V[ind] += C[2k + 1] * P[ind]
717721
end
718722
end
719723

720-
U = A * U
724+
# U = A * U, but we overwrite P to avoid an allocation
725+
mul!(P, A, U)
726+
# P may be seen as an alias for U in the following code
721727

722728
# Padé approximant: (V-U)\(V+U)
723-
tmp1, tmp2 = A, A2 # Reuse already allocated arrays
724-
for ind in eachindex(tmp1)
725-
tmp1[ind] = V[ind] - U[ind]
726-
tmp2[ind] = V[ind] + U[ind]
729+
VminU, VplusU = V, U # Reuse already allocated arrays
730+
for ind in eachindex(V, U)
731+
vi, ui = V[ind], P[ind]
732+
VminU[ind] = vi - ui
733+
VplusU[ind] = vi + ui
727734
end
728-
X = LAPACK.gesv!(tmp1, tmp2)[1]
735+
X = LAPACK.gesv!(VminU, VplusU)[1]
729736
else
730737
s = log2(nA/5.4) # power of 2 later reversed by squaring
731738
if s > 0
@@ -793,10 +800,14 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat
793800
end
794801

795802
if ilo > 1 # apply lower permutations in reverse order
796-
for j in (ilo-1):-1:1; rcswap!(j, Int(scale[j]), X) end
803+
for j in (ilo-1):-1:1
804+
rcswap!(j, Int(scale[j]), X)
805+
end
797806
end
798807
if ihi < n # apply upper permutations in forward order
799-
for j in (ihi+1):n; rcswap!(j, Int(scale[j]), X) end
808+
for j in (ihi+1):n
809+
rcswap!(j, Int(scale[j]), X)
810+
end
800811
end
801812
X
802813
end

0 commit comments

Comments
 (0)