Skip to content

Commit 8305383

Browse files
committed
Merge branch 'main' of https://github.com/Jutho/TensorKit.jl into bd/multifusion
2 parents e1ebb2c + feab207 commit 8305383

File tree

5 files changed

+56
-34
lines changed

5 files changed

+56
-34
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
7979
return adjoint(A), adjoint_pullback
8080
end
8181

82-
function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool = false)
83-
tA = twist(A, is; inv)
84-
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv = !inv), NoTangent()
82+
function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool = false, kwargs...)
83+
tA = twist(A, is; inv, kwargs...)
84+
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv = !inv, kwargs...), NoTangent()
8585
return tA, twist_pullback
8686
end
8787

src/spaces/gradedspace.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ end
200200
Base.summary(io::IO, V::GradedSpace) = print(io, type_repr(typeof(V)))
201201

202202
function Base.show(io::IO, V::GradedSpace)
203-
opn = (get(io, :typeinfo, Any)::DataType == typeof(V) ? "" : type_repr(typeof(V)))
203+
opn = (get(io, :typeinfo, Any)::Type == typeof(V) ? "" : type_repr(typeof(V)))
204204
opn *= "("
205205
if isdual(V)
206206
cls = ")'"

src/tensors/abstracttensor.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -644,20 +644,26 @@ function Base.summary(io::IO, t::AbstractTensorMap)
644644
end
645645

646646
# Human-readable:
647-
function Base.show(io::IO, ::MIME"text/plain", t::AbstractTensorMap)
648-
# 1) show summary: typically d₁×d₂×… ← d₃×d₄×… $(typeof(t)):
647+
function Base.show(io::IO, mime::MIME"text/plain", t::AbstractTensorMap)
648+
# 1) show summary: typically d₁×d₂×… ← d₃×d₄×… $(typeof(t))
649649
summary(io, t)
650-
println(io, ":")
651650

651+
# case without `\n`:
652+
if get(io, :compact, true)
653+
print(io, "(…, ")
654+
show(io, mime, space(t))
655+
print(io, ')')
656+
return nothing
657+
end
658+
659+
# case with `\n`
652660
# 2) show spaces
653-
# println(io, " space(t):")
661+
println(io, ':')
654662
println(io, " codomain: ", codomain(t))
655663
println(io, " domain: ", domain(t))
656664

657665
# 3) [optional]: show data
658-
get(io, :compact, true) && return nothing
659-
ioc = IOContext(io, :typeinfo => sectortype(t))
660666
println(io, "\n\n blocks: ")
661-
show_blocks(io, MIME"text/plain"(), blocks(t))
667+
show_blocks(io, mime, blocks(t))
662668
return nothing
663669
end

src/tensors/indexmanipulations.jl

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -262,47 +262,63 @@ To repartition into an existing destination, see [repartition!](@ref).
262262
end
263263

264264
# Twist
265+
function has_shared_twist(t, inds)
266+
I = sectortype(t)
267+
if BraidingStyle(I) == NoBraiding()
268+
for i in inds
269+
cs = sectors(space(t, i))
270+
all(isunit, cs) || throw(SectorMismatch(lazy"Cannot twist sectors $cs"))
271+
end
272+
return true
273+
elseif BraidingStyle(I) == Bosonic()
274+
return true
275+
else
276+
for i in inds
277+
cs = sectors(space(t, i))
278+
all(isunit twist, cs) || return false
279+
end
280+
return true
281+
end
282+
end
283+
265284
"""
266285
twist!(t::AbstractTensorMap, i::Int; inv::Bool=false) -> t
267-
twist!(t::AbstractTensorMap, is; inv::Bool=false) -> t
286+
twist!(t::AbstractTensorMap, inds; inv::Bool=false) -> t
268287
269-
Apply a twist to the `i`th index of `t`, or all indices in `is`, storing the result in `t`.
288+
Apply a twist to the `i`th index of `t`, or all indices in `inds`, storing the result in `t`.
270289
If `inv=true`, use the inverse twist.
271290
272291
See [`twist`](@ref) for creating a new tensor.
273292
"""
274-
function twist!(t::AbstractTensorMap, is; inv::Bool = false)
275-
if !all(in(allind(t)), is)
276-
msg = "Can't twist indices $is of a tensor with only $(numind(t)) indices."
293+
function twist!(t::AbstractTensorMap, inds; inv::Bool = false)
294+
if !all(in(allind(t)), inds)
295+
msg = "Can't twist indices $inds of a tensor with only $(numind(t)) indices."
277296
throw(ArgumentError(msg))
278297
end
279-
(BraidingStyle(sectortype(t)) == Bosonic() || isempty(is)) && return t
280-
if BraidingStyle(sectortype(t)) == NoBraiding()
281-
for i in is
282-
cs = sectors(space(t, i))
283-
all(isunit, cs) || throw(SectorMismatch(lazy"Cannot twist sectors $cs"))
284-
end
285-
return t
286-
end
298+
has_shared_twist(t, inds) && return t
287299
N₁ = numout(t)
288300
for (f₁, f₂) in fusiontrees(t)
289-
θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), is)
301+
θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), inds)
290302
inv &&= θ')
291-
rmul!(t[f₁, f₂], θ)
303+
scale!(t[f₁, f₂], θ)
292304
end
293305
return t
294306
end
295307

296308
"""
297-
twist(tsrc::AbstractTensorMap, i::Int; inv::Bool=false) -> tdst
298-
twist(tsrc::AbstractTensorMap, is; inv::Bool=false) -> tdst
309+
twist(tsrc::AbstractTensorMap, i::Int; inv::Bool = false, copy::Bool = false) -> tdst
310+
twist(tsrc::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false) -> tdst
299311
300312
Apply a twist to the `i`th index of `tsrc` and return the result as a new tensor.
301-
If `inv=true`, use the inverse twist.
313+
If `inv = true`, use the inverse twist.
314+
If `copy = false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made.
302315
303316
See [`twist!`](@ref) for storing the result in place.
304317
"""
305-
twist(t::AbstractTensorMap, i; inv::Bool = false) = twist!(copy(t), i; inv)
318+
function twist(t::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false)
319+
!copy && has_shared_twist(t, inds) && return t
320+
return twist!(Base.copy(t), inds; inv)
321+
end
306322

307323
# Methods which change the number of indices, implement using `Val(i)` for type inference
308324
"""

test/symmetries/spaces.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,10 @@ end
469469

470470
@timedtestset "show and friends" begin
471471
V = U1Space(i => 1 for i in 1:3)
472-
@test string(V) == "Rep[U₁](1 => 1, 2 => 1, 3 => 1)"
473-
@test string(V') == "Rep[U₁](1 => 1, 2 => 1, 3 => 1)'"
474-
@test sprint((x, y) -> show(x, MIME"text/plain"(), y), V) == "Rep[U₁](…) of dim 3:\n 1 => 1\n 2 => 1\n 3 => 1"
475-
@test sprint((x, y) -> show(x, MIME"text/plain"(), y), V') == "Rep[U₁](…)' of dim 3:\n 1 => 1\n 2 => 1\n 3 => 1"
472+
@test string(V) == "$(type_repr(typeof(V)))(1 => 1, 2 => 1, 3 => 1)"
473+
@test string(V') == "$(type_repr(typeof(V)))(1 => 1, 2 => 1, 3 => 1)'"
474+
@test sprint((x, y) -> show(x, MIME"text/plain"(), y), V) == "$(type_repr(typeof(V)))(…) of dim 3:\n 1 => 1\n 2 => 1\n 3 => 1"
475+
@test sprint((x, y) -> show(x, MIME"text/plain"(), y), V') == "$(type_repr(typeof(V)))(…)' of dim 3:\n 1 => 1\n 2 => 1\n 3 => 1"
476476
end
477477

478478
TensorKit.empty_globalcaches!()

0 commit comments

Comments
 (0)