Skip to content

Commit 95bcd6b

Browse files
committed
Add tests
1 parent b663fbf commit 95bcd6b

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

test/test_tensoroperations.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
using Test: @test, @testset, @inferred
2+
using TensorOperations: @tensor, ncon, tensorcontract
3+
using TensorAlgebra: Matricize
4+
5+
@testset "tensorcontract" begin
6+
A = randn(Float64, (3, 20, 5, 3, 4))
7+
B = randn(Float64, (5, 6, 20, 3))
8+
C1 = @inferred tensorcontract(
9+
A, ((1, 4, 5), (2, 3)), false, B, ((3, 1), (2, 4)), false, ((1, 5, 3, 2, 4), ()), 1.0
10+
)
11+
C2 = @inferred tensorcontract(
12+
A,
13+
((1, 4, 5), (2, 3)),
14+
false,
15+
B,
16+
((3, 1), (2, 4)),
17+
false,
18+
((1, 5, 3, 2, 4), ()),
19+
1.0,
20+
Matricize(),
21+
)
22+
@test C1 C2
23+
end
24+
25+
elts = (Float32, Float64, ComplexF32, ComplexF64)
26+
27+
@testset "tensor network examples ($T)" for T in elts
28+
D1, D2, D3 = 30, 40, 20
29+
d1, d2 = 2, 3
30+
A1 = rand(T, D1, d1, D2) .- 1//2
31+
A2 = rand(T, D2, d2, D3) .- 1//2
32+
rhoL = rand(T, D1, D1) .- 1//2
33+
rhoR = rand(T, D3, D3) .- 1//2
34+
H = rand(T, d1, d2, d1, d2) .- 1//2
35+
36+
@tensor HrA12[a, s1, s2, c] :=
37+
rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2]
38+
@tensor backend = Matricize() HrA12′[a, s1, s2, c] :=
39+
rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2]
40+
41+
@test HrA12 HrA12′
42+
@test HrA12 ncon(
43+
[rhoL, H, A2, rhoR, A1],
44+
[[-1, 1], [-2, -3, 4, 5], [2, 5, 3], [3, -4], [1, 4, 2]];
45+
backend=Matricize(),
46+
)
47+
E = @tensor rhoL[a', a] *
48+
A1[a, s, b] *
49+
A2[b, s', c] *
50+
rhoR[c, c'] *
51+
H[t, t', s, s'] *
52+
conj(A1[a', t, b']) *
53+
conj(A2[b', t', c'])
54+
@test E @tensor backend = Matricize() rhoL[a', a] *
55+
A1[a, s, b] *
56+
A2[b, s', c] *
57+
rhoR[c, c'] *
58+
H[t, t', s, s'] *
59+
conj(A1[a', t, b']) *
60+
conj(A2[b', t', c'])
61+
end
62+
63+
function generate_random_network(
64+
num_contracted_inds, num_open_inds, max_dim, max_ind_per_tensor
65+
)
66+
contracted_indices = repeat(collect(1:num_contracted_inds), 2)
67+
open_indices = collect(1:num_open_inds)
68+
dimensions = [
69+
repeat(rand(1:max_dim, num_contracted_inds), 2)
70+
rand(1:max_dim, num_open_inds)
71+
]
72+
73+
sizes = Vector{Int64}[]
74+
indices = Vector{Int64}[]
75+
76+
while !isempty(contracted_indices) || !isempty(open_indices)
77+
num_inds = rand(
78+
1:min(max_ind_per_tensor, length(contracted_indices) + length(open_indices))
79+
)
80+
81+
cur_inds = Int64[]
82+
cur_dims = Int64[]
83+
84+
for _ in 1:num_inds
85+
curind_index = rand(1:(length(contracted_indices) + length(open_indices)))
86+
87+
if curind_index <= length(contracted_indices)
88+
push!(cur_inds, contracted_indices[curind_index])
89+
push!(cur_dims, dimensions[curind_index])
90+
deleteat!(contracted_indices, curind_index)
91+
deleteat!(dimensions, curind_index)
92+
else
93+
tind = curind_index - length(contracted_indices)
94+
push!(cur_inds, -open_indices[tind])
95+
push!(cur_dims, dimensions[curind_index])
96+
deleteat!(open_indices, tind)
97+
deleteat!(dimensions, curind_index)
98+
end
99+
end
100+
101+
push!(sizes, cur_dims)
102+
push!(indices, cur_inds)
103+
end
104+
return sizes, indices
105+
end
106+
107+
@testset "random contractions" begin
108+
MAX_CONTRACTED_INDICES = 10
109+
MAX_OPEN_INDICES = 5
110+
MAX_DIM = 5
111+
MAX_IND_PER_TENS = 3
112+
NUM_TESTS = 10
113+
114+
for _ in 1:NUM_TESTS
115+
sizes, indices = generate_random_network(rand(1:MAX_CONTRACTED_INDICES), rand(1:MAX_OPEN_INDICES), MAX_DIM, MAX_IND_PER_TENS)
116+
tensors = map(splat(randn), sizes)
117+
result1 = ncon(tensors, indices)
118+
result2 = ncon(tensors, indices; backend=Matricize())
119+
@test result1 result2
120+
end
121+
end

0 commit comments

Comments
 (0)