|
| 1 | +using Test: @test, @testset, @inferred |
| 2 | +using TensorOperations: @tensor, ncon, tensorcontract |
| 3 | +using TensorAlgebra: Matricize |
| 4 | + |
| 5 | +@testset "tensorcontract" begin |
| 6 | + A = randn(Float64, (3, 20, 5, 3, 4)) |
| 7 | + B = randn(Float64, (5, 6, 20, 3)) |
| 8 | + C1 = @inferred tensorcontract( |
| 9 | + A, ((1, 4, 5), (2, 3)), false, B, ((3, 1), (2, 4)), false, ((1, 5, 3, 2, 4), ()), 1.0 |
| 10 | + ) |
| 11 | + C2 = @inferred tensorcontract( |
| 12 | + A, |
| 13 | + ((1, 4, 5), (2, 3)), |
| 14 | + false, |
| 15 | + B, |
| 16 | + ((3, 1), (2, 4)), |
| 17 | + false, |
| 18 | + ((1, 5, 3, 2, 4), ()), |
| 19 | + 1.0, |
| 20 | + Matricize(), |
| 21 | + ) |
| 22 | + @test C1 ≈ C2 |
| 23 | +end |
| 24 | + |
| 25 | +elts = (Float32, Float64, ComplexF32, ComplexF64) |
| 26 | + |
| 27 | +@testset "tensor network examples ($T)" for T in elts |
| 28 | + D1, D2, D3 = 30, 40, 20 |
| 29 | + d1, d2 = 2, 3 |
| 30 | + A1 = rand(T, D1, d1, D2) .- 1//2 |
| 31 | + A2 = rand(T, D2, d2, D3) .- 1//2 |
| 32 | + rhoL = rand(T, D1, D1) .- 1//2 |
| 33 | + rhoR = rand(T, D3, D3) .- 1//2 |
| 34 | + H = rand(T, d1, d2, d1, d2) .- 1//2 |
| 35 | + |
| 36 | + @tensor HrA12[a, s1, s2, c] := |
| 37 | + rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] |
| 38 | + @tensor backend = Matricize() HrA12′[a, s1, s2, c] := |
| 39 | + rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] |
| 40 | + |
| 41 | + @test HrA12 ≈ HrA12′ |
| 42 | + @test HrA12 ≈ ncon( |
| 43 | + [rhoL, H, A2, rhoR, A1], |
| 44 | + [[-1, 1], [-2, -3, 4, 5], [2, 5, 3], [3, -4], [1, 4, 2]]; |
| 45 | + backend=Matricize(), |
| 46 | + ) |
| 47 | + E = @tensor rhoL[a', a] * |
| 48 | + A1[a, s, b] * |
| 49 | + A2[b, s', c] * |
| 50 | + rhoR[c, c'] * |
| 51 | + H[t, t', s, s'] * |
| 52 | + conj(A1[a', t, b']) * |
| 53 | + conj(A2[b', t', c']) |
| 54 | + @test E ≈ @tensor backend = Matricize() rhoL[a', a] * |
| 55 | + A1[a, s, b] * |
| 56 | + A2[b, s', c] * |
| 57 | + rhoR[c, c'] * |
| 58 | + H[t, t', s, s'] * |
| 59 | + conj(A1[a', t, b']) * |
| 60 | + conj(A2[b', t', c']) |
| 61 | +end |
| 62 | + |
| 63 | +function generate_random_network( |
| 64 | + num_contracted_inds, num_open_inds, max_dim, max_ind_per_tensor |
| 65 | +) |
| 66 | + contracted_indices = repeat(collect(1:num_contracted_inds), 2) |
| 67 | + open_indices = collect(1:num_open_inds) |
| 68 | + dimensions = [ |
| 69 | + repeat(rand(1:max_dim, num_contracted_inds), 2) |
| 70 | + rand(1:max_dim, num_open_inds) |
| 71 | + ] |
| 72 | + |
| 73 | + sizes = Vector{Int64}[] |
| 74 | + indices = Vector{Int64}[] |
| 75 | + |
| 76 | + while !isempty(contracted_indices) || !isempty(open_indices) |
| 77 | + num_inds = rand( |
| 78 | + 1:min(max_ind_per_tensor, length(contracted_indices) + length(open_indices)) |
| 79 | + ) |
| 80 | + |
| 81 | + cur_inds = Int64[] |
| 82 | + cur_dims = Int64[] |
| 83 | + |
| 84 | + for _ in 1:num_inds |
| 85 | + curind_index = rand(1:(length(contracted_indices) + length(open_indices))) |
| 86 | + |
| 87 | + if curind_index <= length(contracted_indices) |
| 88 | + push!(cur_inds, contracted_indices[curind_index]) |
| 89 | + push!(cur_dims, dimensions[curind_index]) |
| 90 | + deleteat!(contracted_indices, curind_index) |
| 91 | + deleteat!(dimensions, curind_index) |
| 92 | + else |
| 93 | + tind = curind_index - length(contracted_indices) |
| 94 | + push!(cur_inds, -open_indices[tind]) |
| 95 | + push!(cur_dims, dimensions[curind_index]) |
| 96 | + deleteat!(open_indices, tind) |
| 97 | + deleteat!(dimensions, curind_index) |
| 98 | + end |
| 99 | + end |
| 100 | + |
| 101 | + push!(sizes, cur_dims) |
| 102 | + push!(indices, cur_inds) |
| 103 | + end |
| 104 | + return sizes, indices |
| 105 | +end |
| 106 | + |
| 107 | +@testset "random contractions" begin |
| 108 | + MAX_CONTRACTED_INDICES = 10 |
| 109 | + MAX_OPEN_INDICES = 5 |
| 110 | + MAX_DIM = 5 |
| 111 | + MAX_IND_PER_TENS = 3 |
| 112 | + NUM_TESTS = 10 |
| 113 | + |
| 114 | + for _ in 1:NUM_TESTS |
| 115 | + sizes, indices = generate_random_network(rand(1:MAX_CONTRACTED_INDICES), rand(1:MAX_OPEN_INDICES), MAX_DIM, MAX_IND_PER_TENS) |
| 116 | + tensors = map(splat(randn), sizes) |
| 117 | + result1 = ncon(tensors, indices) |
| 118 | + result2 = ncon(tensors, indices; backend=Matricize()) |
| 119 | + @test result1 ≈ result2 |
| 120 | + end |
| 121 | +end |
0 commit comments