Skip to content

Commit 38f7d2e

Browse files
committed
reimplement otimes
1 parent 9f2c4a8 commit 38f7d2e

File tree

1 file changed

+30
-35
lines changed

1 file changed

+30
-35
lines changed

src/tensors/linalg.jl

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -564,49 +564,44 @@ Compute the tensor product between two `AbstractTensorMap` instances, which resu
564564
new `TensorMap` instance whose codomain is `codomain(t1) ⊗ codomain(t2)` and whose domain
565565
is `domain(t1) ⊗ domain(t2)`.
566566
"""
567-
function (t1::AbstractTensorMap, t2::AbstractTensorMap)
568-
(S = spacetype(t1)) === spacetype(t2) ||
569-
throw(SpaceMismatch("spacetype(t1) ≠ spacetype(t2)"))
570-
cod1, cod2 = codomain(t1), codomain(t2)
571-
dom1, dom2 = domain(t1), domain(t2)
572-
p12 = (
573-
(codomainind(t1)..., (codomainind(t2) .+ numind(t1))...),
574-
(domainind(t1)..., (domainind(t2) .+ numind(t1))...),
567+
function (A::AbstractTensorMap, B::AbstractTensorMap)
568+
(S = spacetype(A)) === spacetype(B) || throw(SpaceMismatch("incompatible space types"))
569+
570+
# allocate destination with correct scalartype
571+
pA = ((codomainind(A)..., domainind(A)...), ())
572+
pB = ((), (codomainind(B)..., domainind(B)...))
573+
NA = numind(A)
574+
pAB = (
575+
(codomainind(A)..., (codomainind(B) .+ NA)...),
576+
(domainind(A)..., (domainind(B) .+ NA)...),
575577
)
576-
577-
T = promote_type(scalartype(t1), scalartype(t2))
578-
TC = promote_type(sectorscalartype(sectortype(t1)), T)
579-
t = TO.tensoralloc_contract(
580-
TC,
581-
t1, ((codomainind(t1)..., domainind(t1)...), ()), false,
582-
t2, ((), (codomainind(t2)..., domainind(t2)...)), false,
583-
p12, Val(false)
584-
)
585-
586-
zerovector!(t)
587-
for (f1l, f1r) in fusiontrees(t1)
588-
for (f2l, f2r) in fusiontrees(t2)
578+
TC = TO.promote_contract(scalartype(A), scalartype(B))
579+
C = TO.tensoralloc_contract(TC, A, pA, false, B, pB, false, pAB, Val(false))
580+
zerovector!(C)
581+
582+
# implement tensor product
583+
for (f1l, f1r) in fusiontrees(A)
584+
@inbounds a = A[f1l, f1r]
585+
for (f2l, f2r) in fusiontrees(B)
586+
@inbounds b = B[f2l, f2r]
589587
c1 = f1l.coupled # = f1r.coupled
590588
c2 = f2l.coupled # = f2r.coupled
591-
for c in c1 c2
592-
for μ in 1:Nsymbol(c1, c2, c)
593-
for (fl, coeff1) in merge(f1l, f2l, c, μ)
594-
for (fr, coeff2) in merge(f1r, f2r, c, μ)
595-
d1 = dim(cod1, f1l.uncoupled)
596-
d2 = dim(cod2, f2l.uncoupled)
597-
d3 = dim(dom1, f1r.uncoupled)
598-
d4 = dim(dom2, f2r.uncoupled)
599-
m1 = sreshape(t1[f1l, f1r], (d1, 1, d3, 1))
600-
m2 = sreshape(t2[f2l, f2r], (1, d2, 1, d4))
601-
m = sreshape(t[fl, fr], (d1, d2, d3, d4))
602-
m .+= coeff1 .* conj(coeff2) .* m1 .* m2
603-
end
589+
for c in c1 c2, μ in 1:Nsymbol(c1, c2, c)
590+
for (fl, coeff1) in merge(f1l, f2l, c, μ)
591+
for (fr, coeff2) in merge(f1r, f2r, c, μ)
592+
TO.tensorcontract!(
593+
C[fl, fr],
594+
A[f1l, f1r], pA, false,
595+
B[f2l, f2r], pB, false,
596+
pAB,
597+
coeff1 * conj(coeff2), One()
598+
)
604599
end
605600
end
606601
end
607602
end
608603
end
609-
return t
604+
return C
610605
end
611606

612607
# deligne product of tensors

0 commit comments

Comments
 (0)