Skip to content

Commit ef42572

Browse files
authored
insertleftunit, insertrightunit and removeunit (#187)
* extend `insertunit` to `HomSpace` * extend `insertunit` to `AbstractTensorMap` * Add `insertunit` tests for `HomSpace` * Add `insertunit` tests for `TensorMap` * Add `removeunit` functionality * Add tests `removeunit` * fixup! Add `insertunit` tests for `TensorMap` * improve type stability * Rewrite in terms of `insertleftunit` and `insertrightunit` * also update docs * fix missing kwargs * update checks in hope of fixing type ambiguity * type stability changes
1 parent c041bfe commit ef42572

File tree

10 files changed

+237
-12
lines changed

10 files changed

+237
-12
lines changed

docs/src/lib/spaces.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ fuse
106106
ismonomorphic
107107
isepimorphic
108108
isisomorphic
109-
insertunit
110109
```
111110

112111
There are also specific methods for `HomSpace` instances, that are used in determining
@@ -116,4 +115,6 @@ the resuling `HomSpace` after applying certain tensor operations.
116115
TensorKit.permute(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
117116
TensorKit.select(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
118117
TensorKit.compose(::HomSpace{S}, ::HomSpace{S}) where {S}
118+
insertleftunit(::HomSpace, ::Int)
119+
insertrightunit(::HomSpace, ::Int)
119120
```

docs/src/lib/tensors.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ braid(::AbstractTensorMap, ::Index2Tuple, ::IndexTuple; ::Bool)
175175
transpose(::AbstractTensorMap, ::Index2Tuple; ::Bool)
176176
repartition(::AbstractTensorMap, ::Int, ::Int; ::Bool)
177177
twist(::AbstractTensorMap, ::Int; ::Bool)
178+
insertleftunit(::AbstractTensorMap, ::Int)
179+
insertrightunit(::AbstractTensorMap, ::Int)
178180
```
179181

180182
```@docs
@@ -224,4 +226,4 @@ and only accept the `TensorMap` object as well as the method-specific algorithm
224226
arguments.
225227

226228

227-
TODO: document svd truncation types
229+
TODO: document svd truncation types

src/TensorKit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ export TruncationScheme
3131
export SpaceMismatch, SectorMismatch, IndexError # error types
3232

3333
# general vector space methods
34-
export space, field, dual, dim, reduceddim, dims, fuse, flip, isdual, insertunit, oplus
34+
export space, field, dual, dim, reduceddim, dims, fuse, flip, isdual, oplus,
35+
insertleftunit, insertrightunit, removeunit
3536

3637
# partial order for vector spaces
3738
export infimum, supremum, isisomorphic, ismonomorphic, isepimorphic

src/auxiliary/auxiliary.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ function _kron(A, B)
4141
return C
4242
end
4343

44+
@noinline _boundserror(P, i) = throw(BoundsError(P, i))
45+
@noinline _nontrivialspaceerror(P, i) = throw(ArgumentError(lazy"Attempting to remove a non-trivial space $(P[i])"))
46+
4447
# Compat implementation:
4548
@static if VERSION < v"1.7"
4649
macro constprop(setting, ex)

src/auxiliary/deprecate.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,6 @@ end
5656

5757
Base.@deprecate EuclideanProduct() EuclideanInnerProduct()
5858

59+
Base.@deprecate insertunit(P::ProductSpace, args...; kwargs...) insertleftunit(args...; kwargs...)
60+
5961
#! format: on

src/spaces/homspace.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,56 @@ function compose(W::HomSpace{S}, V::HomSpace{S}) where {S}
174174
return HomSpace(codomain(W), domain(V))
175175
end
176176

177+
"""
178+
insertleftunit(W::HomSpace, i::Int=numind(W) + 1; conj=false, dual=false)
179+
180+
Insert a trivial vector space, isomorphic to the underlying field, at position `i`.
181+
More specifically, adds a left monoidal unit or its dual.
182+
183+
See also [`insertrightunit`](@ref), [`removeunit`](@ref).
184+
"""
185+
@constprop :aggressive function insertleftunit(W::HomSpace, i::Int=numind(W) + 1;
186+
conj::Bool=false, dual::Bool=false)
187+
if i numout(W)
188+
return insertleftunit(codomain(W), i; conj, dual) domain(W)
189+
else
190+
return codomain(W) insertleftunit(domain(W), i - numout(W); conj, dual)
191+
end
192+
end
193+
194+
"""
195+
insertrightunit(W::HomSpace, i::Int=numind(W); conj=false, dual=false)
196+
197+
Insert a trivial vector space, isomorphic to the underlying field, after position `i`.
198+
More specifically, adds a right monoidal unit or its dual.
199+
200+
See also [`insertleftunit`](@ref), [`removeunit`](@ref).
201+
"""
202+
@constprop :aggressive function insertrightunit(W::HomSpace, i::Int=numind(W);
203+
conj::Bool=false, dual::Bool=false)
204+
if i numout(W)
205+
return insertrightunit(codomain(W), i; conj, dual) domain(W)
206+
else
207+
return codomain(W) insertrightunit(domain(W), i - numout(W); conj, dual)
208+
end
209+
end
210+
211+
"""
212+
removeunit(P::HomSpace, i::Int)
213+
214+
This removes a trivial tensor product factor at position `1 ≤ i ≤ N`.
215+
For this to work, that factor has to be isomorphic to the field of scalars.
216+
217+
This operation undoes the work of [`insertleftunit`](@ref) or [`insertrightunit`](@ref).
218+
"""
219+
@constprop :aggressive function removeunit(P::HomSpace, i::Int)
220+
if i numout(P)
221+
return removeunit(codomain(P), i) domain(P)
222+
else
223+
return codomain(P) removeunit(domain(P), i - numout(P))
224+
end
225+
end
226+
177227
# Block and fusion tree ranges: structure information for building tensors
178228
#--------------------------------------------------------------------------
179229
struct FusionBlockStructure{I,N,F₁,F₂}

src/spaces/productspace.jl

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,15 @@ fuse(P::ProductSpace{S,0}) where {S<:ElementarySpace} = oneunit(S)
246246
fuse(P::ProductSpace{S}) where {S<:ElementarySpace} = fuse(P.spaces...)
247247

248248
"""
249-
insertunit(P::ProductSpace, i::Int = length(P)+1; dual = false, conj = false)
249+
insertleftunit(P::ProductSpace, i::Int=length(P) + 1; conj=false, dual=false)
250250
251-
For `P::ProductSpace{S,N}`, this adds an extra tensor product factor at position
252-
`1 <= i <= N+1` (last position by default) which is just the `S`-equivalent of the
253-
underlying field of scalars, i.e. `oneunit(S)`. With the keyword arguments, one can choose
254-
to insert the conjugated or dual space instead, which are all isomorphic to the field of
255-
scalars.
251+
Insert a trivial vector space, isomorphic to the underlying field, at position `i`.
252+
More specifically, adds a left monoidal unit or its dual.
253+
254+
See also [`insertrightunit`](@ref), [`removeunit`](@ref).
256255
"""
257-
function insertunit(P::ProductSpace, i::Int=length(P) + 1; dual=false, conj=false)
256+
function insertleftunit(P::ProductSpace, i::Int=length(P) + 1;
257+
conj::Bool=false, dual::Bool=false)
258258
u = oneunit(spacetype(P))
259259
if dual
260260
u = TensorKit.dual(u)
@@ -265,6 +265,40 @@ function insertunit(P::ProductSpace, i::Int=length(P) + 1; dual=false, conj=fals
265265
return ProductSpace(TupleTools.insertafter(P.spaces, i - 1, (u,)))
266266
end
267267

268+
"""
269+
insertrightunit(P::ProductSpace, i::Int=lenght(P); conj=false, dual=false)
270+
271+
Insert a trivial vector space, isomorphic to the underlying field, after position `i`.
272+
More specifically, adds a right monoidal unit or its dual.
273+
274+
See also [`insertleftunit`](@ref), [`removeunit`](@ref).
275+
"""
276+
function insertrightunit(P::ProductSpace, i::Int=length(P);
277+
conj::Bool=false, dual::Bool=false)
278+
u = oneunit(spacetype(P))
279+
if dual
280+
u = TensorKit.dual(u)
281+
end
282+
if conj
283+
u = TensorKit.conj(u)
284+
end
285+
return ProductSpace(TupleTools.insertafter(P.spaces, i, (u,)))
286+
end
287+
288+
"""
289+
removeunit(P::ProductSpace, i::Int)
290+
291+
This removes a trivial tensor product factor at position `1 ≤ i ≤ N`.
292+
For this to work, that factor has to be isomorphic to the field of scalars.
293+
294+
This operation undoes the work of [`insertunit`](@ref).
295+
"""
296+
function removeunit(P::ProductSpace, i::Int)
297+
1 i length(P) || _boundserror(P, i)
298+
isisomorphic(P[i], oneunit(P[i])) || _nontrivialspaceerror(P, i)
299+
return ProductSpace{spacetype(P)}(TupleTools.deleteat(P.spaces, i))
300+
end
301+
268302
# Functionality for extracting and iterating over spaces
269303
#--------------------------------------------------------
270304
Base.length(P::ProductSpace) = length(P.spaces)

src/tensors/indexmanipulations.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,84 @@ See [`twist!`](@ref) for storing the result in place.
292292
"""
293293
twist(t::AbstractTensorMap, i; inv::Bool=false) = twist!(copy(t), i; inv)
294294

295+
"""
296+
insertleftunit(tsrc::AbstractTensorMap, i::Int=numind(t) + 1;
297+
conj=false, dual=false, copy=false) -> tdst
298+
299+
Insert a trivial vector space, isomorphic to the underlying field, at position `i`.
300+
More specifically, adds a left monoidal unit or its dual.
301+
302+
If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made.
303+
304+
See also [`insertrightunit`](@ref) and [`removeunit`](@ref).
305+
"""
306+
@constprop :aggressive function insertleftunit(t::AbstractTensorMap,
307+
i::Int=numind(t) + 1; copy::Bool=true,
308+
conj::Bool=false, dual::Bool=false)
309+
W = insertleftunit(space(t), i; conj, dual)
310+
tdst = similar(t, W)
311+
for (c, b) in blocks(t)
312+
copy!(block(tdst, c), b)
313+
end
314+
return tdst
315+
end
316+
@constprop :aggressive function insertleftunit(t::TensorMap, i::Int=numind(t) + 1;
317+
copy::Bool=false,
318+
conj::Bool=false, dual::Bool=false)
319+
W = insertleftunit(space(t), i; conj, dual)
320+
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
321+
end
322+
323+
"""
324+
insertrightunit(tsrc::AbstractTensorMap, i::Int=numind(t);
325+
conj=false, dual=false, copy=false) -> tdst
326+
327+
Insert a trivial vector space, isomorphic to the underlying field, after position `i`.
328+
More specifically, adds a right monoidal unit or its dual.
329+
330+
If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made.
331+
332+
See also [`insertleftunit`](@ref) and [`removeunit`](@ref).
333+
"""
334+
@constprop :aggressive function insertrightunit(t::AbstractTensorMap, i::Int=numind(t);
335+
copy::Bool=true, kwargs...)
336+
W = insertrightunit(space(t), i; kwargs...)
337+
tdst = similar(t, W)
338+
for (c, b) in blocks(t)
339+
copy!(block(tdst, c), b)
340+
end
341+
return tdst
342+
end
343+
@constprop :aggressive function insertrightunit(t::TensorMap, i::Int=numind(t);
344+
copy::Bool=false, kwargs...)
345+
W = insertrightunit(space(t), i; kwargs...)
346+
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
347+
end
348+
349+
"""
350+
removeunit(tsrc::AbstractTensorMap, i::Int; copy=false) -> tdst
351+
352+
This removes a trivial tensor product factor at position `1 ≤ i ≤ N`.
353+
For this to work, that factor has to be isomorphic to the field of scalars.
354+
355+
If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made.
356+
357+
This operation undoes the work of [`insertunit`](@ref).
358+
"""
359+
@constprop :aggressive function removeunit(t::TensorMap, i::Int; copy::Bool=false)
360+
W = removeunit(space(t), i)
361+
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
362+
end
363+
@constprop :aggressive function removeunit(t::AbstractTensorMap, i::Int;
364+
copy::Bool=true)
365+
W = removeunit(space(t), i)
366+
tdst = similar(t, W)
367+
for (c, b) in blocks(t)
368+
copy!(block(tdst, c), b)
369+
end
370+
return tdst
371+
end
372+
295373
# Fusing and splitting
296374
# TODO: add functionality for easy fusing and splitting of tensor indices
297375

test/spaces.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,10 @@ println("------------------------------------")
278278
@test @constinferred((V1 V2, V3 V4)) == P
279279
@test @constinferred((V1, V2, V3 V4)) == P
280280
@test @constinferred((V1, V2 V3, V4)) == P
281-
@test @constinferred(insertunit(P, 3)) == V1 * V2 * oneunit(V1) * V3 * V4
281+
@test V1 * V2 * oneunit(V1) * V3 * V4 ==
282+
@constinferred(insertleftunit(P, 3)) ==
283+
@constinferred(insertrightunit(P, 2))
284+
@test @constinferred(removeunit(V1 * V2 * oneunit(V1)' * V3 * V4, 3)) == P
282285
@test fuse(V1, V2', V3) V1 V2' V3
283286
@test fuse(V1, V2', V3) V1 V2' V3
284287
@test fuse(V1, V2', V3) V1 V2' V3
@@ -338,7 +341,10 @@ println("------------------------------------")
338341
@test @constinferred(*(V1, V2, V3)) == P
339342
@test @constinferred((V1, V2, V3)) == P
340343
@test @constinferred(adjoint(P)) == dual(P) == V3' V2' V1'
341-
@test @constinferred(insertunit(P, 3; conj=true)) == V1 * V2 * oneunit(V1)' * V3
344+
@test V1 * V2 * oneunit(V1)' * V3 ==
345+
@constinferred(insertleftunit(P, 3; conj=true)) ==
346+
@constinferred(insertrightunit(P, 2; conj=true))
347+
@test P == @constinferred(removeunit(insertleftunit(P, 3), 3))
342348
@test fuse(V1, V2', V3) V1 V2' V3
343349
@test fuse(V1, V2', V3) V1 V2' V3 fuse(V1 V2' V3)
344350
@test fuse(V1, V2') V3 V1 V2' V3
@@ -419,5 +425,21 @@ println("------------------------------------")
419425
@test W == @constinferred permute(W, ((1, 2), (3, 4, 5)))
420426
@test permute(W, ((2, 4, 5), (3, 1))) == (V2 V4' V5' V3 V1')
421427
@test (V1 V2 V1 V2) == @constinferred TensorKit.compose(W, W')
428+
@test (V1 V2 V3 V4 V5 oneunit(V5)) ==
429+
@constinferred(insertleftunit(W)) ==
430+
@constinferred(insertrightunit(W))
431+
@test @constinferred(removeunit(insertleftunit(W), $(numind(W) + 1))) == W
432+
@test (V1 V2 V3 V4 V5 oneunit(V5)') ==
433+
@constinferred(insertleftunit(W; conj=true)) ==
434+
@constinferred(insertrightunit(W; conj=true))
435+
@test (oneunit(V1) V1 V2 V3 V4 V5) ==
436+
@constinferred(insertleftunit(W, 1)) ==
437+
@constinferred(insertrightunit(W, 0))
438+
@test (V1 V2 oneunit(V1) V3 V4 V5) ==
439+
@constinferred(insertrightunit(W, 2))
440+
@test (V1 V2 oneunit(V1) V3 V4 V5) == @constinferred(insertleftunit(W, 3))
441+
@test @constinferred(removeunit(insertleftunit(W, 3), 3)) == W
442+
@test @constinferred(insertrightunit(one(V1) V1, 0)) == (oneunit(V1) V1)
443+
@test_throws BoundsError insertleftunit(one(V1) V1, 0)
422444
end
423445
end

test/tensors.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,38 @@ for V in spacelist
173173
@test w * w' == (w * w')^2
174174
end
175175
end
176+
@timedtestset "Trivial spaces" begin
177+
W = V1 V2 V3 V4 V5
178+
for T in (Float32, ComplexF64)
179+
t = @constinferred rand(T, W)
180+
t2 = @constinferred insertleftunit(t)
181+
@test t2 == @constinferred insertrightunit(t)
182+
@test numind(t2) == numind(t) + 1
183+
@test space(t2) == insertleftunit(space(t))
184+
@test scalartype(t2) === T
185+
@test t.data === t2.data
186+
@test @constinferred(removeunit(t2, $(numind(t2)))) == t
187+
t3 = @constinferred insertleftunit(t; copy=true)
188+
@test t3 == @constinferred insertrightunit(t; copy=true)
189+
@test t.data !== t3.data
190+
for (c, b) in blocks(t)
191+
@test b == block(t3, c)
192+
end
193+
@test @constinferred(removeunit(t3, $(numind(t3)))) == t
194+
t4 = @constinferred insertrightunit(t, 3; dual=true)
195+
@test numin(t4) == numin(t) && numout(t4) == numout(t) + 1
196+
for (c, b) in blocks(t)
197+
@test b == block(t4, c)
198+
end
199+
@test @constinferred(removeunit(t4, 4)) == t
200+
t5 = @constinferred insertleftunit(t, 4; dual=true)
201+
@test numin(t5) == numin(t) + 1 && numout(t5) == numout(t)
202+
for (c, b) in blocks(t)
203+
@test b == block(t5, c)
204+
end
205+
@test @constinferred(removeunit(t5, 4)) == t
206+
end
207+
end
176208
if hasfusiontensor(I)
177209
@timedtestset "Basic linear algebra: test via conversion" begin
178210
W = V1 V2 V3 V4 V5

0 commit comments

Comments
 (0)