Skip to content

Commit cf15e34

Browse files
committed
Improve performance by fusing matrix products
1 parent 4bda01c commit cf15e34

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

src/matmul.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ function _mul!(::MatMulMode{:slow}, C, A::AbstractMatrix, B::AbstractVecOrMat,
207207
C .*= α
208208
else
209209
AB = Matrix{eltype(C)}(undef, size(A, 1), size(B, 2))
210-
C .= AB .* α .+ C .* β
210+
C .= _matmul_rec!(AB, A, B) .* α .+ C .* β
211211
end
212212
end
213213
t = all(isguaranteed, A) & all(isguaranteed, B) & isguaranteed(α) & isguaranteed(β)
@@ -444,13 +444,9 @@ function __mul(A::AbstractMatrix{Interval{T}}, B::AbstractVecOrMat{Interval{T}})
444444
mA, rA = _vec_or_mat_midradius(A)
445445
mB, rB = _vec_or_mat_midradius(B)
446446

447-
cache_1 = Matrix{T}(undef, size(A, 1), size(B, 2))
448-
cache_2 = Matrix{T}(undef, size(A, 1), size(B, 2))
449-
450-
ρA = sign.(mA) .* min.(abs.(mA), rA)
451-
ρB = sign.(mB) .* min.(abs.(mB), rB)
452-
mC = _matmul_rec!(cache_1, mA, mB) + _matmul_rec!(cache_2, ρA, ρB)
453-
μ = _matmul_rec!(cache_1, abs.(mA), abs.(mB)) + _matmul_rec!(cache_2, abs.(ρA), abs.(ρB))
447+
cache_1 = zeros(T, size(A, 1), size(B, 2))
448+
cache_2 = zeros(T, size(A, 1), size(B, 2))
449+
mC, μ = _fused_matmul!(cache_1, cache_2, mA, rA, mB, rB)
454450

455451
γ = _add_round.(_mul_round.(convert(T, k + 1), eps.(μ), RoundUp), IntervalArithmetic._mul_round(IntervalArithmetic._inv_round(u2, RoundUp), floatmin(T), RoundUp), RoundUp)
456452

@@ -469,6 +465,23 @@ function _vec_or_mat_midradius(A::AbstractVecOrMat{Interval{T}}) where {T<:Abstr
469465
return mA, rA
470466
end
471467

468+
function _fused_matmul!(mC, μ, mA, rA, mB, rB)
469+
Threads.@threads for j axes(mB, 2)
470+
for l axes(mA, 2)
471+
@inbounds for i axes(mA, 1)
472+
a, c = mA[i,l], rA[i,l]
473+
b, d = mB[l,j], rB[l,j]
474+
e = sign(a) * min(abs(a), c)
475+
f = sign(b) * min(abs(b), d)
476+
p = a*b + e*f
477+
mC[i,j] += p
478+
μ[i,j] += abs(p)
479+
end
480+
end
481+
end
482+
return mC, μ
483+
end
484+
472485
#-
473486

474487
let fenv_consts = Vector{Cint}(undef, 9)
@@ -494,9 +507,6 @@ else
494507
end
495508

496509
function _call_gem_openblas_upward!(C, A::AbstractMatrix{Float64}, B::AbstractMatrix{Float64})
497-
prev_rounding = _getrounding() # save current rounding mode
498-
_setrounding(JL_FE_UPWARD) # set rounding mode to upward
499-
500510
m, k = size(A)
501511
n = size(B, 2)
502512

@@ -506,6 +516,8 @@ function _call_gem_openblas_upward!(C, A::AbstractMatrix{Float64}, B::AbstractMa
506516
transA = 'N'
507517
transB = 'N'
508518

519+
prev_rounding = _getrounding() # save current rounding mode
520+
_setrounding(JL_FE_UPWARD) # set rounding mode to upward
509521
try
510522
ccall((:dgemm_64_, OpenBLASConsistentFPCSR_jll.libopenblas), Cvoid,
511523
(Ref{UInt8}, Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt},

0 commit comments

Comments
 (0)