diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index af0c7f38..4aa7aa23 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -79,9 +79,9 @@ function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap) return adjoint(A), adjoint_pullback end -function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool = false) - tA = twist(A, is; inv) - twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv = !inv), NoTangent() +function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool = false, kwargs...) + tA = twist(A, is; inv, kwargs...) + twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv = !inv, kwargs...), NoTangent() return tA, twist_pullback end diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index 7465a7bb..33337a86 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -262,47 +262,59 @@ To repartition into an existing destination, see [repartition!](@ref). end # Twist +function has_shared_twist(t, inds) + I = sectortype(t) + if BraidingStyle(I) == NoBraiding() + for i in inds + cs = sectors(space(t, i)) + all(isone, cs) || throw(SectorMismatch(lazy"Cannot twist sectors $cs")) + end + return true + elseif BraidingStyle(I) == Bosonic() + return true + else + return isempty(inds) + end +end + """ twist!(t::AbstractTensorMap, i::Int; inv::Bool=false) -> t - twist!(t::AbstractTensorMap, is; inv::Bool=false) -> t + twist!(t::AbstractTensorMap, inds; inv::Bool=false) -> t -Apply a twist to the `i`th index of `t`, or all indices in `is`, storing the result in `t`. +Apply a twist to the `i`th index of `t`, or all indices in `inds`, storing the result in `t`. If `inv=true`, use the inverse twist. See [`twist`](@ref) for creating a new tensor. """ -function twist!(t::AbstractTensorMap, is; inv::Bool = false) - if !all(in(allind(t)), is) - msg = "Can't twist indices $is of a tensor with only $(numind(t)) indices." +function twist!(t::AbstractTensorMap, inds; inv::Bool = false) + if !all(in(allind(t)), inds) + msg = "Can't twist indices $inds of a tensor with only $(numind(t)) indices." throw(ArgumentError(msg)) end - (BraidingStyle(sectortype(t)) == Bosonic() || isempty(is)) && return t - if BraidingStyle(sectortype(t)) == NoBraiding() - for i in is - cs = sectors(space(t, i)) - all(isone, cs) || throw(SectorMismatch(lazy"Cannot twist sectors $cs")) - end - return t - end + has_shared_twist(t, inds) && return t N₁ = numout(t) for (f₁, f₂) in fusiontrees(t) - θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), is) + θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), inds) inv && (θ = θ') - rmul!(t[f₁, f₂], θ) + scale!(t[f₁, f₂], θ) end return t end """ - twist(tsrc::AbstractTensorMap, i::Int; inv::Bool=false) -> tdst - twist(tsrc::AbstractTensorMap, is; inv::Bool=false) -> tdst + twist(tsrc::AbstractTensorMap, i::Int; inv::Bool = false, copy::Bool = false) -> tdst + twist(tsrc::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false) -> tdst Apply a twist to the `i`th index of `tsrc` and return the result as a new tensor. -If `inv=true`, use the inverse twist. +If `inv = true`, use the inverse twist. +If `copy = false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. See [`twist!`](@ref) for storing the result in place. """ -twist(t::AbstractTensorMap, i; inv::Bool = false) = twist!(copy(t), i; inv) +function twist(t::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false) + !copy && has_shared_twist(t, inds) && return t + return twist!(copy(t), is; inv) +end # Methods which change the number of indices, implement using `Val(i)` for type inference """