-
Notifications
You must be signed in to change notification settings - Fork 57
Expand file tree
/
Copy pathlinalg.jl
More file actions
129 lines (116 loc) · 4.81 KB
/
linalg.jl
File metadata and controls
129 lines (116 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Linear Algebra chainrules
# -------------------------
# Reverse rule for `a + b`: the output cotangent is handed back unchanged to both summands.
function ChainRulesCore.rrule(::typeof(+), a::AbstractTensorMap, b::AbstractTensorMap)
    function plus_pullback(Δc)
        return (NoTangent(), Δc, Δc)
    end
    return a + b, plus_pullback
end
# Reverse rule for unary negation: the cotangent of `a` is the negated output cotangent.
function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap)
    negate_pullback(Δc) = (NoTangent(), -Δc)
    return -a, negate_pullback
end
# Reverse rule for `a - b`: `a` receives the output cotangent, `b` receives its negation.
function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap, b::AbstractTensorMap)
    function minus_pullback(Δc)
        return (NoTangent(), Δc, -Δc)
    end
    return a - b, minus_pullback
end
# Reverse rule for the composition `a * b` of two tensor maps.
# Cotangents are returned lazily: `Δc * b'` for `a` and `a' * Δc` for `b`.
function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::AbstractTensorMap)
    function times_pullback(Δc)
        da = @thunk(Δc * b')
        db = @thunk(a' * Δc)
        return NoTangent(), da, db
    end
    return a * b, times_pullback
end
# Reverse rule for scaling a tensor map by a scalar, `c = a * b`.
#
# Returns `(a * b, times_pullback)`. The pullback maps the output cotangent `Δc` to
#   - `Δc * b'`     for the tensor `a`,
#   - `dot(a, Δc)`  for the scalar `b`,
# both wrapped in `@thunk` so they are only computed when requested.
#
# Each cotangent is passed through `ProjectTo` of the corresponding primal input (the
# same mechanism the `⊗` rule in this file relies on). Without it a real scalar `b`
# would be handed a complex cotangent whenever `a` or `Δc` is complex.
function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::Number)
    project_a = ProjectTo(a)
    project_b = ProjectTo(b)
    function times_pullback(Δc_)
        # Materialise the incoming cotangent once; both thunks below share it.
        Δc = unthunk(Δc_)
        da = @thunk(project_a(Δc * b'))
        db = @thunk(project_b(dot(a, Δc)))
        return NoTangent(), da, db
    end
    return a * b, times_pullback
end
# Reverse rule for scaling a tensor map by a scalar from the left, `c = a * b`.
#
# Returns `(a * b, times_pullback)`. The pullback maps the output cotangent `Δc` to
#   - `dot(b, Δc)`  for the scalar `a`,
#   - `a' * Δc`     for the tensor `b`,
# both wrapped in `@thunk` so they are only computed when requested.
#
# Each cotangent is passed through `ProjectTo` of the corresponding primal input (the
# same mechanism the `⊗` rule in this file relies on). Without it a real scalar `a`
# would be handed a complex cotangent whenever `b` or `Δc` is complex.
function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap)
    project_a = ProjectTo(a)
    project_b = ProjectTo(b)
    function times_pullback(Δc_)
        # Materialise the incoming cotangent once; both thunks below share it.
        Δc = unthunk(Δc_)
        da = @thunk(project_a(dot(b, Δc)))
        db = @thunk(project_b(a' * Δc))
        return NoTangent(), da, db
    end
    return a * b, times_pullback
end
# Reverse rule for the tensor product `C = A ⊗ B`.
#
# The cotangent of each factor is obtained by contracting the output cotangent `ΔC`
# with the conjugate of the *other* factor (twisted on its dual legs), and is then
# projected back onto the tangent space of the corresponding input. Both cotangents
# are lazy thunks.
function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTensorMap)
    C = A ⊗ B
    projectA = ProjectTo(A)
    projectB = ProjectTo(B)
    function otimes_pullback(ΔC_)
        # TODO: this rule is probably better written in terms of inner products,
        # using planarcontract and adjoint tensormaps would remove the twists.
        ΔC = unthunk(ΔC_)
        # Index pair for ΔC: first tuple is built from A's indices, second from B's,
        # each shifted by the number of legs of the other factor that precede them.
        legsA = (codomainind(A)..., (domainind(A) .+ numout(B))...)
        legsB = ((codomainind(B) .+ numout(A))...,
                 (domainind(B) .+ (numin(A) + numout(A)))...)
        pΔC = (legsA, legsB)
        dA_ = @thunk let
            outA = (codomainind(A), domainind(A))
            pB = (allind(B), ())
            TA = promote_contract(scalartype(ΔC), scalartype(B))
            accA = zerovector(A, TA)
            dualB = filter(x -> isdual(space(B, x)), allind(B))
            tB = twist(B, dualB)
            accA = tensorcontract!(accA, ΔC, pΔC, false, tB, pB, true, outA)
            return projectA(accA)
        end
        dB_ = @thunk let
            outB = (codomainind(B), domainind(B))
            pA = ((), allind(A))
            TB = promote_contract(scalartype(ΔC), scalartype(A))
            accB = zerovector(B, TB)
            dualA = filter(x -> isdual(space(A, x)), allind(A))
            tA = twist(A, dualA)
            accB = tensorcontract!(accB, tA, pA, true, ΔC, pΔC, false, outB)
            return projectB(accB)
        end
        return NoTangent(), dA_, dB_
    end
    return C, otimes_pullback
end
# Reverse rule for `permute(tsrc, p)`.
#
# The `copy` keyword is accepted for signature compatibility, but the primal and the
# cotangent are always computed with `copy=true`. The pullback applies the inverse
# permutation to the output cotangent; the index tuple `p` gets no tangent.
function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple;
                              copy::Bool=false)
    function permute_pullback(Δtdst)
        flat = linearize(p)
        invp = TensorKit._canonicalize(TupleTools.invperm(flat), tsrc)
        Δtsrc = permute(unthunk(Δtdst), invp; copy=true)
        return NoTangent(), Δtsrc, NoTangent()
    end
    return permute(tsrc, p; copy=true), permute_pullback
end
# Reverse rule for the trace: the cotangent of `A` is the scalar output cotangent
# multiplied by the identity map on `domain(A)`.
function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
    function tr_pullback(Δtr)
        ΔA = Δtr * id(domain(A))
        return NoTangent(), ΔA
    end
    return tr(A), tr_pullback
end
# Reverse rule for `adjoint(A)`: the pullback takes the adjoint of the (unthunked)
# output cotangent.
function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
    function adjoint_pullback(Δadjoint)
        ΔA = adjoint(unthunk(Δadjoint))
        return NoTangent(), ΔA
    end
    return adjoint(A), adjoint_pullback
end
# Reverse rule for the inner product `dot(a, b)`.
# Lazy cotangents: `b * Δd'` for `a` and `a * Δd` for `b`.
function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
    function dot_pullback(Δd)
        da = @thunk(b * Δd')
        db = @thunk(a * Δd)
        return NoTangent(), da, db
    end
    return dot(a, b), dot_pullback
end
# Reverse rule for `norm(a, p)`; only `p == 2` is supported, anything else errors.
#
# The pullback always returns three cotangents (function, `a`, `p`), including when
# `p` was left at its default.
function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
    if p != 2
        error("currently only implemented for p = 2")
    end
    n = norm(a, p)
    function norm_pullback(Δn)
        # hypot with eps keeps the denominator non-zero when `n` is zero
        denom = hypot(n, eps(one(n)))
        Δa = a * (Δn' + Δn) / 2 / denom
        return NoTangent(), Δa, NoTangent()
    end
    return n, norm_pullback
end
# Reverse rule for `real(a)`.
# For a real-valued `a` the cotangent passes through untouched; otherwise the
# unthunked cotangent is converted with `complex`.
function ChainRulesCore.rrule(::typeof(real), a::AbstractTensorMap)
    function real_pullback(Δa)
        if eltype(a) <: Real
            return NoTangent(), Δa
        else
            return NoTangent(), complex(unthunk(Δa))
        end
    end
    return real(a), real_pullback
end
# Reverse rule for `imag(a)`.
# A real-valued `a` gets a `ZeroTangent()`; otherwise the cotangent is a complex
# tensor with zero real part and the incoming cotangent as imaginary part.
function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap)
    function imag_pullback(Δa)
        Δ = unthunk(Δa)
        if eltype(a) <: Real
            return NoTangent(), ZeroTangent()
        else
            return NoTangent(), complex(zerovector(Δ), Δ)
        end
    end
    return imag(a), imag_pullback
end
# Reverse rule for `exp(A)`, evaluated block by block.
#
# Each block of `A` is differentiated through `rrule_via_ad(cfg, exp, block)`; the
# per-block results fill the primal output and the per-block pullbacks are stored,
# keyed by the block label. The pullback applies them to the matching blocks of the
# output cotangent and projects the assembled tensor onto the tangent space of `A`.
# Errors unless `domain(A) == codomain(A)`.
function ChainRulesCore.rrule(cfg::RuleConfig, ::typeof(exp), A::AbstractTensorMap)
    if domain(A) != codomain(A)
        error("Exponential of a tensor only exist when domain == codomain.")
    end
    project_A = ProjectTo(A)
    expA = similar(A)
    block_pullbacks = map(blocks(A)) do (sector, data)
        expdata, block_pb = rrule_via_ad(cfg, exp, data)
        copy!(block(expA, sector), expdata)
        return sector => block_pb
    end
    function exp_pullback(ΔC_)
        ΔC = unthunk(ΔC_)
        dA = similar(A)
        for (sector, block_pb) in block_pullbacks
            dst = block(dA, sector)
            # last entry of the block pullback's result is the cotangent of the block
            copy!(dst, last(block_pb(block(ΔC, sector))))
        end
        return NoTangent(), project_A(dA)
    end
    return expA, exp_pullback
end