Skip to content

Commit 175fb5e

Browse files
committed
reorganize constructors
1 parent 88510bc commit 175fb5e

File tree

2 files changed

+183
-147
lines changed

2 files changed

+183
-147
lines changed

src/tensors/abstracttensor.jl

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -609,8 +609,8 @@ function Base.imag(t::AbstractTensorMap)
609609
end
610610
end
611611

612-
# Conversion to Array:
613-
#----------------------
612+
# Conversion to/from Array:
613+
#--------------------------
614614
# probably not optimized for speed, only for checking purposes
615615
function Base.convert(::Type{Array}, t::AbstractTensorMap)
616616
I = sectortype(t)
@@ -631,9 +631,54 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
631631
end
632632
end
633633

634+
"""
635+
project_symmetric!(t::AbstractTensorMap, data::AbstractArray) -> t
636+
637+
Project the data from a dense array `data` into the tensor map `t`. This function discards
638+
any data that does not fit the symmetry structure of `t`.
639+
"""
640+
function project_symmetric!(t::AbstractTensorMap, data::AbstractArray)
641+
# dimension check
642+
codom, dom = codomain(t), domain(t)
643+
arraysize = dims(t)
644+
matsize = (dim(codom), dim(dom))
645+
(size(data) == arraysize || size(data) == matsize) ||
646+
throw(DimensionMismatch("input data has incompatible size for the given tensor"))
647+
648+
I = sectortype(t)
649+
if I === Trivial && t isa TensorMap
650+
copy!(t.data, reshape(data, length(t.data)))
651+
return t
652+
end
653+
654+
for ((f₁, f₂), subblock) in subblocks(t)
655+
F = convert(Array, (f₁, f₂))
656+
dataslice = sview(
657+
data, axes(codomain(t), f₁.uncoupled)..., axes(domain(t), f₂.uncoupled)...
658+
)
659+
if FusionStyle(I) === UniqueFusion()
660+
Fscalar = only(F) # contains a single element
661+
scale!(subblock, dataslice, conj(Fscalar))
662+
else
663+
szbF = _interleave(size(F), size(subblock))
664+
indset1 = ntuple(identity, numind(t))
665+
indset2 = 2 .* indset1
666+
indset3 = indset2 .- 1
667+
TensorOperations.tensorcontract!(
668+
subblock,
669+
F, ((), indset1), true,
670+
sreshape(dataslice, szbF), (indset3, indset2), false,
671+
(indset1, ()),
672+
inv(dim(f₁.coupled)), false
673+
)
674+
end
675+
end
676+
677+
return t
678+
end
679+
634680
# Show and friends
635681
# ----------------
636-
637682
function Base.dims2string(V::HomSpace)
638683
str_cod = numout(V) == 0 ? "()" : join(dim.(codomain(V)), '×')
639684
str_dom = numin(V) == 0 ? "()" : join(dim.(domain(V)), '×')

0 commit comments

Comments
 (0)