-
Notifications
You must be signed in to change notification settings - Fork 27
Expand file tree
/
Copy pathutil.jl
More file actions
258 lines (228 loc) · 7.63 KB
/
util.jl
File metadata and controls
258 lines (228 loc) · 7.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# Cyclic successor and predecessor of a directional CTMRG environment index `i` within
# `1:total`; `mod1` wraps around so that the result always stays inside that range
function _next(i, total)
    return mod1(i + 1, total)
end
function _prev(i, total)
    return mod1(i - 1, total)
end
# Get next and previous coordinate (direction, row, column), given a direction and going around the environment clockwise
# Step to the coordinate that follows `(dir, row, col)`: the direction advances cyclically
# through 1:4 while, depending on `dir`, either the row or the column is shifted by one
# site with periodic wrapping. Any `dir` outside 1:4 yields `nothing`.
function _next_coordinate((dir, row, col), rowsize, colsize)
    dir == 1 && return (_next(dir, 4), row, _next(col, colsize))
    dir == 2 && return (_next(dir, 4), _next(row, rowsize), col)
    dir == 3 && return (_next(dir, 4), row, _prev(col, colsize))
    dir == 4 && return (_next(dir, 4), _prev(row, rowsize), col)
    return nothing
end
# Step to the coordinate that precedes `(dir, row, col)`: the direction retreats cyclically
# through 1:4 while, depending on `dir`, either the row or the column is shifted by one
# site with periodic wrapping. Any `dir` outside 1:4 yields `nothing`.
function _prev_coordinate((dir, row, col), rowsize, colsize)
    dir == 1 && return (_prev(dir, 4), _next(row, rowsize), col)
    dir == 2 && return (_prev(dir, 4), row, _prev(col, colsize))
    dir == 3 && return (_prev(dir, 4), _prev(row, rowsize), col)
    dir == 4 && return (_prev(dir, 4), row, _next(col, colsize))
    return nothing
end
# Iterator over all coordinates
"""
    eachcoordinate(x, [dirs=1:4])

Enumerate all `(dir, row, col)` triples.

Only the generic function is declared at this point; no method is attached here.
"""
function eachcoordinate end
# Exclude `eachcoordinate` from differentiation, whatever arguments it is called with
@non_differentiable eachcoordinate(args...)
# Multiply two TensorMaps entry by entry, one block at a time. The result is allocated
# with `similar(a₁)` and every one of its blocks is overwritten with the broadcasted
# product of the matching blocks of `a₁` and `a₂`.
function _elementwise_mult(a₁::AbstractTensorMap, a₂::AbstractTensorMap)
    out = similar(a₁)
    for (sector, out_block) in blocks(out)
        product = block(a₁, sector) .* block(a₂, sector)
        copyto!(out_block, product)
    end
    return out
end
# Raise `a` to the power `pow`, except that for a negative exponent any `a` whose
# magnitude lies below `tol` is mapped to zero instead of being (pseudo-)inverted
function _safe_pow(a::Number, pow::Real, tol::Real)
    if pow < 0 && abs(a) < tol
        return zero(a)
    end
    return a^pow
end
"""
sdiag_pow(s, pow::Real; tol::Real=eps(scalartype(s))^(3 / 4))
Compute `s^pow` for a diagonal matrix `s`.
"""
function sdiag_pow(s::DiagonalTensorMap, pow::Real; tol::Real=eps(scalartype(s))^(3 / 4))
# Relative tol w.r.t. largest singular value (use norm(∘, Inf) to make differentiable)
tol *= norm(s, Inf)
spow = DiagonalTensorMap(_safe_pow.(s.data, pow, tol), space(s, 1))
return spow
end
# Same operation for a general one-leg-to-one-leg TensorMap: only the diagonal of each
# block is read, raised to the power `pow` and written back as a diagonal matrix.
function sdiag_pow(
    s::AbstractTensorMap{T,S,1,1}, pow::Real; tol::Real=eps(scalartype(s))^(3 / 4)
) where {T,S}
    # Turn the relative tolerance into an absolute one, measured against the largest
    # singular value (norm(∘, Inf) keeps this step differentiable)
    cutoff = tol * norm(s, Inf)
    result = similar(s)
    for (sector, blk) in blocks(s)
        diagonal = LinearAlgebra.diag(blk)
        powered = _safe_pow.(diagonal, pow, cutoff)
        copyto!(block(result, sector), LinearAlgebra.diagm(powered))
    end
    return result
end
# Reverse rule for `sdiag_pow`: the cotangent w.r.t. `s` is the element-wise product of
# the incoming cotangent with `pow * conj(s)^(pow - 1)`. The exponent `pow` and the
# function itself receive no tangent.
function ChainRulesCore.rrule(
    ::typeof(sdiag_pow),
    s::AbstractTensorMap,
    pow::Real;
    tol::Real=eps(scalartype(s))^(3 / 4),
)
    # `sdiag_pow` already rescales `tol` by `norm(s, Inf)`, so the relative tolerance is
    # forwarded untouched. Rescaling it here as well would apply the factor twice and
    # make the primal computed under AD differ from a plain `sdiag_pow` call.
    spow = sdiag_pow(s, pow; tol)
    # `s'` has the same Inf-norm as `s`, hence the same absolute cutoff is used here
    spow_minus1_conj = scale!(sdiag_pow(s', pow - 1; tol), pow)
    function sdiag_pow_pullback(c̄_)
        c̄ = unthunk(c̄_)
        # One tangent per primal input: the function, `s` and `pow`
        return (
            ChainRulesCore.NoTangent(),
            _elementwise_mult(c̄, spow_minus1_conj),
            ChainRulesCore.NoTangent(),
        )
    end
    return spow, sdiag_pow_pullback
end
"""
absorb_s(U::AbstractTensorMap, S::DiagonalTensorMap, V::AbstractTensorMap)
Given `tsvd` result `U`, `S` and `V`, absorb singular values `S` into `U` and `V` by:
```
U -> U * sqrt(S), V -> sqrt(S) * V
```
"""
function absorb_s(U::AbstractTensorMap, S::DiagonalTensorMap, V::AbstractTensorMap)
@assert !isdual(space(S, 1))
sqrt_S = sdiag_pow(S, 0.5)
return U * sqrt_S, sqrt_S * V
end
"""
flip_svd(u::AbstractTensorMap, s::DiagonalTensorMap, vh::AbstractTensorMap)
Given `tsvd` result `u ← s ← vh`, flip the arrow between the three tensors
to `u2 → s2 → vh2` such that
```
u * s * vh = (@tensor t2[-1; -2] := u2[-1; 1] * s2[1; 2] * vh2[2; -2])
```
"""
function flip_svd(u::AbstractTensorMap, s::DiagonalTensorMap, vh::AbstractTensorMap)
return flip(u, 2), DiagonalTensorMap(flip(s, (1, 2))), flip(vh, 1)
end
"""
twistdual(t::AbstractTensorMap, i)
twistdual!(t::AbstractTensorMap, i)
Twist the i-th leg of a tensor `t` if it represents a dual space.
"""
function twistdual!(t::AbstractTensorMap, i::Int)
isdual(space(t, i)) || return t
return twist!(t, i)
end
# Several legs at once: keep only those indices in `is` whose space is dual and twist
# them in a single `twist!` call
function twistdual!(t::AbstractTensorMap, is)
    dual_legs = filter(is) do i
        isdual(space(t, i))
    end
    return twist!(t, dual_legs)
end
# Non-mutating variant: operate on a copy so that `t` itself is left untouched
function twistdual(t::AbstractTensorMap, is)
    return twistdual!(copy(t), is)
end
"""
str(t)
Fermionic supertrace by using `@tensor`.
"""
str(t::AbstractTensorMap) = _str(BraidingStyle(sectortype(t)), t)
# With bosonic braiding the supertrace is just the ordinary trace
function _str(::Bosonic, t::AbstractTensorMap)
    return tr(t)
end
# With fermionic braiding the trace is generated as a `@tensor` expression in which the
# labels 1:N are used for both index groups of `t`, for any number of legs `N`
@generated function _str(::Fermionic, t::AbstractTensorMap{<:Any,<:Any,N,N}) where {N}
    labels = ntuple(identity, N)
    traced = tensorexpr(:t, labels, labels)
    return macroexpand(@__MODULE__, :(@tensor $traced))
end
"""
trmul(H, ρ)
Compute `tr(H * ρ)` without forming `H * ρ`.
"""
@generated function trmul(
H::AbstractTensorMap{<:Any,S,N,N}, ρ::AbstractTensorMap{<:Any,S,N,N}
) where {S,N}
Hex = tensorexpr(:H, ntuple(identity, N), ntuple(i -> i + N, N))
ρex = tensorexpr(:ρ, ntuple(i -> i + N, N), ntuple(identity, N))
return macroexpand(@__MODULE__, :(@tensor $Hex * $ρex))
end
# Check whether the diagonal of any block of `S` contains two neighbouring entries that
# coincide up to the absolute tolerance `atol` or the relative tolerance `rtol`.
# Only adjacent entries are compared, and only their real parts.
function is_degenerate_spectrum(
    S; atol::Real=0, rtol::Real=atol > 0 ? 0 : sqrt(eps(scalartype(S)))
)
    for (_, b) in blocks(S)
        vals = real(diag(b))
        # Pair every entry with its successor
        neighbours = zip(vals, Iterators.drop(vals, 1))
        any(((x, y),) -> isapprox(x, y; atol, rtol), neighbours) && return true
    end
    return false
end
# There are no rrules for rotl90 and rotr90 in ChainRules.jl
# Reverse rule for `rotl90`: a non-zero cotangent is rotated back with `rotr90`, while a
# zero cotangent is handed back unchanged
function ChainRulesCore.rrule(::typeof(rotl90), a::AbstractMatrix)
    function rotl90_pullback(x)
        iszero(x) && return NoTangent(), x
        # Rebuild a structural tangent as a matrix of the primal type before rotating
        x̄ = if x isa Tangent
            ChainRulesCore.construct(typeof(a), ChainRulesCore.backing(x))
        else
            x
        end
        return NoTangent(), rotr90(x̄)
    end
    return rotl90(a), rotl90_pullback
end
# Reverse rule for `rotr90`: a non-zero cotangent is rotated back with `rotl90`, while a
# zero cotangent is handed back unchanged
function ChainRulesCore.rrule(::typeof(rotr90), a::AbstractMatrix)
    function rotr90_pullback(x)
        iszero(x) && return NoTangent(), x
        # Rebuild a structural tangent as a matrix of the primal type before rotating
        x̄ = if x isa Tangent
            ChainRulesCore.construct(typeof(a), ChainRulesCore.backing(x))
        else
            x
        end
        return NoTangent(), rotl90(x̄)
    end
    return rotr90(a), rotr90_pullback
end
# Differentiable setindex! alternative: instead of mutating `a`, return a copy of it
# (converted to the type of `a`) in which the entries selected by `args...` are set to `v`
function _setindex(a::AbstractArray, v, args...)
    updated = convert(typeof(a), copy(a))::typeof(a)
    setindex!(updated, v, args...)
    return updated
end
# Reverse rule for `_setindex`. The cotangent of the inserted value `tv` is the slice of
# the incoming cotangent at `args...`; the cotangent of `a` is the incoming cotangent
# with that same slice zeroed out. The index arguments get `ZeroTangent()`s.
function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...)
    t = _setindex(a, tv, args...)
    function _setindex_pullback(v)
        index_tangents = fill(ZeroTangent(), length(args))
        if iszero(v)
            return (NoTangent(), ZeroTangent(), ZeroTangent(), index_tangents...)
        end

        # Rebuild a structural tangent as an array of the primal type
        v̄ = if v isa Tangent
            ChainRulesCore.construct(typeof(a), ChainRulesCore.backing(v))
        else
            v
        end
        # TODO: Fix this for ZeroTangents
        if typeof(v̄) != typeof(a)
            v̄ = convert(typeof(a), v̄)
        end

        backwards_tv = v̄[args...]
        backwards_a = copy(v̄)
        # A single selected entry is zeroed directly, a selected slice entry by entry
        if typeof(backwards_tv) == eltype(a)
            backwards_a[args...] = zero(backwards_tv)
        else
            backwards_a[args...] = zero.(backwards_tv)
        end
        return (NoTangent(), backwards_a, backwards_tv, index_tangents...)
    end
    return t, _setindex_pullback
end
# TODO: link to Zygote.showgrad once they update documenter.jl
"""
    @showtypeofgrad(x)

Macro utility to show the type of the gradient that is about to accumulate for `x`.
The gradient itself is passed through unchanged.

See also `Zygote.@showgrad`.
"""
macro showtypeofgrad(x)
return :(
# Hook into the backward pass of the escaped expression `x`
Zygote.hook($(esc(x))) do x̄
# Print the expression together with the type of its gradient ...
println($"∂($x) = ", repr(typeof(x̄)))
# ... and hand the gradient on untouched
x̄
end
)
end