Skip to content

Commit feab207

Browse files
leburgellkdvos
andauthored
Avoid copy in twist for tensors with bosonic braiding (#305)
* Avoid copy in `twist` for tensors with bosonic braiding * Fix typo * add `copy` keyword argument * handle kwargs in `twist` rrule * share more code * add code suggestion * fix name alias * fix variable rename --------- Co-authored-by: Lukas Devos <[email protected]>
1 parent db09a2f commit feab207

File tree

2 files changed

+38
-22
lines changed

2 files changed

+38
-22
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
7979
return adjoint(A), adjoint_pullback
8080
end
8181

82-
function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool = false)
83-
tA = twist(A, is; inv)
84-
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv = !inv), NoTangent()
82+
function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool = false, kwargs...)
83+
tA = twist(A, is; inv, kwargs...)
84+
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv = !inv, kwargs...), NoTangent()
8585
return tA, twist_pullback
8686
end
8787

src/tensors/indexmanipulations.jl

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -262,47 +262,63 @@ To repartition into an existing destination, see [repartition!](@ref).
262262
end
263263

264264
# Twist
265+
function has_shared_twist(t, inds)
266+
I = sectortype(t)
267+
if BraidingStyle(I) == NoBraiding()
268+
for i in inds
269+
cs = sectors(space(t, i))
270+
all(isone, cs) || throw(SectorMismatch(lazy"Cannot twist sectors $cs"))
271+
end
272+
return true
273+
elseif BraidingStyle(I) == Bosonic()
274+
return true
275+
else
276+
for i in inds
277+
cs = sectors(space(t, i))
278+
all(isone twist, cs) || return false
279+
end
280+
return true
281+
end
282+
end
283+
265284
"""
266285
twist!(t::AbstractTensorMap, i::Int; inv::Bool=false) -> t
267-
twist!(t::AbstractTensorMap, is; inv::Bool=false) -> t
286+
twist!(t::AbstractTensorMap, inds; inv::Bool=false) -> t
268287
269-
Apply a twist to the `i`th index of `t`, or all indices in `is`, storing the result in `t`.
288+
Apply a twist to the `i`th index of `t`, or all indices in `inds`, storing the result in `t`.
270289
If `inv=true`, use the inverse twist.
271290
272291
See [`twist`](@ref) for creating a new tensor.
273292
"""
274-
function twist!(t::AbstractTensorMap, is; inv::Bool = false)
275-
if !all(in(allind(t)), is)
276-
msg = "Can't twist indices $is of a tensor with only $(numind(t)) indices."
293+
function twist!(t::AbstractTensorMap, inds; inv::Bool = false)
294+
if !all(in(allind(t)), inds)
295+
msg = "Can't twist indices $inds of a tensor with only $(numind(t)) indices."
277296
throw(ArgumentError(msg))
278297
end
279-
(BraidingStyle(sectortype(t)) == Bosonic() || isempty(is)) && return t
280-
if BraidingStyle(sectortype(t)) == NoBraiding()
281-
for i in is
282-
cs = sectors(space(t, i))
283-
all(isone, cs) || throw(SectorMismatch(lazy"Cannot twist sectors $cs"))
284-
end
285-
return t
286-
end
298+
has_shared_twist(t, inds) && return t
287299
N₁ = numout(t)
288300
for (f₁, f₂) in fusiontrees(t)
289-
θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), is)
301+
θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), inds)
290302
inv &&= θ')
291-
rmul!(t[f₁, f₂], θ)
303+
scale!(t[f₁, f₂], θ)
292304
end
293305
return t
294306
end
295307

296308
"""
297-
twist(tsrc::AbstractTensorMap, i::Int; inv::Bool=false) -> tdst
298-
twist(tsrc::AbstractTensorMap, is; inv::Bool=false) -> tdst
309+
twist(tsrc::AbstractTensorMap, i::Int; inv::Bool = false, copy::Bool = false) -> tdst
310+
twist(tsrc::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false) -> tdst
299311
300312
Apply a twist to the `i`th index of `tsrc` and return the result as a new tensor.
301-
If `inv=true`, use the inverse twist.
313+
If `inv = true`, use the inverse twist.
314+
If `copy = false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made.
302315
303316
See [`twist!`](@ref) for storing the result in place.
304317
"""
305-
twist(t::AbstractTensorMap, i; inv::Bool = false) = twist!(copy(t), i; inv)
318+
function twist(t::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false)
319+
!copy && has_shared_twist(t, inds) && return t
320+
return twist!(Base.copy(t), inds; inv)
321+
end
306322

307323
# Methods which change the number of indices, implement using `Val(i)` for type inference
308324
"""

0 commit comments

Comments
 (0)