Skip to content

Commit f5dee09

Browse files
authored
Make sample type agnostic and GPU compatible (#93)
1 parent 8896bb1 commit f5dee09

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/mps.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,8 @@ function sample(rng::AbstractRNG, m::MPS)
655655
error("sample: MPS is not normalized, norm=$(norm(m[1]))")
656656
end
657657

658+
ElT = scalartype(m)
659+
658660
result = zeros(Int, N)
659661
A = m[1]
660662

@@ -664,16 +666,16 @@ function sample(rng::AbstractRNG, m::MPS)
664666
# Compute the probability of each state
665667
# one-by-one and stop when the random
666668
# number r is below the total prob so far
667-
pdisc = 0.0
669+
pdisc = zero(real(ElT))
668670
r = rand(rng)
669671
# Will need n,An, and pn below
670672
n = 1
671673
An = ITensor()
672-
pn = 0.0
674+
pn = zero(real(ElT))
673675
while n <= d
674676
projn = ITensor(s)
675-
projn[s => n] = 1.0
676-
An = A * dag(projn)
677+
projn[s => n] = one(ElT)
678+
An = A * dag(adapt(datatype(A), projn))
677679
pn = real(scalar(dag(An) * An))
678680
pdisc += pn
679681
(r < pdisc) && break
@@ -682,7 +684,7 @@ function sample(rng::AbstractRNG, m::MPS)
682684
result[j] = n
683685
if j < N
684686
A = m[j + 1] * An
685-
A *= (1.0 / sqrt(pn))
687+
A *= (one(ElT) / sqrt(pn))
686688
end
687689
end
688690
return result
@@ -749,7 +751,7 @@ function correlation_matrix(
749751
end_site = last(sites)
750752

751753
N = length(psi)
752-
ElT = promote_itensor_eltype(psi)
754+
ElT = scalartype(psi)
753755
s = siteinds(psi)
754756

755757
Op1 = _Op1 #make copies into which we can insert "F" string operators, and then restore.
@@ -983,7 +985,7 @@ updens, dndens = expect(psi, "Nup", "Ndn") # pass more than one operator
983985
function expect(psi::MPS, ops; sites=1:length(psi), site_range=nothing)
984986
psi = copy(psi)
985987
N = length(psi)
986-
ElT = promote_itensor_eltype(psi)
988+
ElT = scalartype(psi)
987989
s = siteinds(psi)
988990

989991
if !isnothing(site_range)

0 commit comments

Comments
 (0)