Skip to content

Commit 45c79e9

Browse files
authored
Implement sensitivities for BLAS.gemm (#25)
These are ported from Nabla.
1 parent 6d2be82 commit 45c79e9

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

src/rules/blas.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ These implementations were ported from the wonderful DiffLinearAlgebra
33
package (https://github.com/invenia/DiffLinearAlgebra.jl).
44
=#
55

6+
using LinearAlgebra: BlasFloat
7+
using LinearAlgebra.BLAS: gemm
8+
69
"""
    _zeros(x)

Return a freshly allocated array with the same shape and element type as `x`, filled
with `zero(eltype(x))`. `x` itself is not modified.
"""
function _zeros(x)
    buffer = similar(x)
    return fill!(buffer, zero(eltype(x)))
end
710

811
# Wrap the plain function `∂` in a `Rule`. A `Zero` cotangent is passed through
# untouched; any other cotangent is unwrapped with `extern` and then handed to `∂`.
_rule_via(∂) = Rule(ΔΩ -> isa(ΔΩ, Zero) ? ΔΩ : ∂(extern(ΔΩ)))
@@ -72,3 +75,37 @@ function rrule(f::typeof(BLAS.gemv), tA, A, x)
7275
Ω, (dtA, dα, dA, dx) = rrule(f, tA, one(eltype(A)), A, x)
7376
return Ω, (dtA, dA, dx)
7477
end
78+
79+
#####
80+
##### `BLAS.gemm`
81+
#####
82+
83+
"""
    rrule(::typeof(gemm), tA::Char, tB::Char, α::T, A, B) where T<:BlasFloat

Reverse-mode rule for `C = gemm(tA, tB, α, A, B)`.

Returns `C` and a 5-tuple of rules, one per argument `(tA, tB, α, A, B)`. The two
transpose flags get a `DNERule()`; the rules for `α`, `A` and `B` map a cotangent `C̄`
of `C` to the corresponding cotangent of that argument.
"""
function rrule(::typeof(gemm), tA::Char, tB::Char, α::T,
               A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
    C = gemm(tA, tB, α, A, B)
    # `C / α` recovers the unscaled product without a second `gemm`, but it is NaN when
    # `α` is zero (then `C` is all zeros); recompute the unscaled product in that case.
    ∂α = C̄ -> iszero(α) ? sum(C̄ .* gemm(tA, tB, A, B)) : sum(C̄ .* C) / α
    # NOTE(review): every flag other than 'N' is treated as 'T'. That matches 'C' only
    # for real element types — confirm the intended behaviour for complex `T`.
    if uppercase(tA) === 'N'
        if uppercase(tB) === 'N'
            # C = α * A * B
            ∂A = C̄ -> gemm('N', 'T', α, C̄, B)
            ∂B = C̄ -> gemm('T', 'N', α, A, C̄)
        else
            # C = α * A * Bᵀ
            ∂A = C̄ -> gemm('N', 'N', α, C̄, B)
            ∂B = C̄ -> gemm('T', 'N', α, C̄, A)
        end
    else
        if uppercase(tB) === 'N'
            # C = α * Aᵀ * B
            ∂A = C̄ -> gemm('N', 'T', α, B, C̄)
            ∂B = C̄ -> gemm('N', 'N', α, A, C̄)
        else
            # C = α * Aᵀ * Bᵀ
            ∂A = C̄ -> gemm('T', 'T', α, B, C̄)
            ∂B = C̄ -> gemm('T', 'T', α, C̄, A)
        end
    end
    return C, (DNERule(), DNERule(), _rule_via(∂α), _rule_via(∂A), _rule_via(∂B))
end
106+
107+
"""
    rrule(::typeof(gemm), tA::Char, tB::Char, A, B) where T<:BlasFloat

Reverse-mode rule for `gemm(tA, tB, A, B)`. Delegates to the five-argument rule with a
scale factor of `one(T)` and discards the rule for that scale factor, returning the
result together with the rules for `(tA, tB, A, B)`.
"""
function rrule(::typeof(gemm), tA::Char, tB::Char,
               A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
    unit = one(T)
    Ω, rules = rrule(gemm, tA, tB, unit, A, B)
    ∂tA, ∂tB, _, ∂A, ∂B = rules
    return Ω, (∂tA, ∂tB, ∂A, ∂B)
end

test/rules/blas.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using LinearAlgebra.BLAS: gemm
2+
3+
@testset "BLAS" begin
    @testset "gemm" begin
        rng = MersenneTwister(1)
        dims = 3:5
        # Sweep a few small sizes and every combination of the 'N'/'T' flags.
        for m in dims, n in dims, p in dims, tA in ('N', 'T'), tB in ('N', 'T')
            α = randn(rng)
            # Shapes are chosen so that op(A) is m×n and op(B) is n×p.
            A = randn(rng, tA === 'N' ? (m, n) : (n, m))
            B = randn(rng, tB === 'N' ? (n, p) : (p, n))
            C = gemm(tA, tB, α, A, B)
            fAB, (dtA, dtB, dα, dA, dB) = rrule(gemm, tA, tB, α, A, B)
            # The primal value returned by the rule must match a direct call.
            @test C ≈ fAB
            @test dtA isa ChainRules.DNERule
            @test dtB isa ChainRules.DNERule
            # Compare each rule against a finite-difference estimate, varying one
            # argument at a time.
            for (f, x, dx) in [(X->gemm(tA, tB, X, A, B), α, dα),
                               (X->gemm(tA, tB, α, X, B), A, dA),
                               (X->gemm(tA, tB, α, A, X), B, dB)]
                ȳ = randn(rng, size(C)...)
                x̄_ad = dx(ȳ)
                x̄_fd = j′vp(central_fdm(5, 1), f, ȳ, x)
                @test x̄_ad ≈ x̄_fd rtol=1e-9 atol=1e-9
            end
        end
    end
end

0 commit comments

Comments
 (0)