@@ -3,6 +3,9 @@ These implementations were ported from the wonderful DiffLinearAlgebra
3
3
package (https://github.com/invenia/DiffLinearAlgebra.jl).
4
4
=#
5
5
6
+ using LinearAlgebra: BlasFloat
7
+ using LinearAlgebra. BLAS: gemm
8
+
6
9
_zeros (x) = fill! (similar (x), zero (eltype (x)))
7
10
8
11
_rule_via (∂) = Rule (ΔΩ -> isa (ΔΩ, Zero) ? ΔΩ : ∂ (extern (ΔΩ)))
@@ -72,3 +75,37 @@ function rrule(f::typeof(BLAS.gemv), tA, A, x)
72
75
Ω, (dtA, dα, dA, dx) = rrule (f, tA, one (eltype (A)), A, x)
73
76
return Ω, (dtA, dA, dx)
74
77
end
78
+
79
+ # ####
80
+ # #### `BLAS.gemm`
81
+ # ####
82
+
83
+ function rrule (:: typeof (gemm), tA:: Char , tB:: Char , α:: T ,
84
+ A:: AbstractMatrix{T} , B:: AbstractMatrix{T} ) where T<: BlasFloat
85
+ C = gemm (tA, tB, α, A, B)
86
+ ∂α = C̄ -> sum (C̄ .* C) / α
87
+ if uppercase (tA) === ' N'
88
+ if uppercase (tB) === ' N'
89
+ ∂A = C̄ -> gemm (' N' , ' T' , α, C̄, B)
90
+ ∂B = C̄ -> gemm (' T' , ' N' , α, A, C̄)
91
+ else
92
+ ∂A = C̄ -> gemm (' N' , ' N' , α, C̄, B)
93
+ ∂B = C̄ -> gemm (' T' , ' N' , α, C̄, A)
94
+ end
95
+ else
96
+ if uppercase (tB) === ' N'
97
+ ∂A = C̄ -> gemm (' N' , ' T' , α, B, C̄)
98
+ ∂B = C̄ -> gemm (' N' , ' N' , α, A, C̄)
99
+ else
100
+ ∂A = C̄ -> gemm (' T' , ' T' , α, B, C̄)
101
+ ∂B = C̄ -> gemm (' T' , ' T' , α, C̄, A)
102
+ end
103
+ end
104
+ return C, (DNERule (), DNERule (), _rule_via (∂α), _rule_via (∂A), _rule_via (∂B))
105
+ end
106
+
107
+ function rrule (:: typeof (gemm), tA:: Char , tB:: Char ,
108
+ A:: AbstractMatrix{T} , B:: AbstractMatrix{T} ) where T<: BlasFloat
109
+ C, (dtA, dtB, _, dA, dB) = rrule (gemm, tA, tB, one (T), A, B)
110
+ return C, (dtA, dtB, dA, dB)
111
+ end
0 commit comments