Skip to content

Commit 5eea6e0

Browse files
committed
Added a couple tensor contraction tests.
1 parent f4ba16b commit 5eea6e0

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ end
2424

2525
@time include("offsetarrays.jl")
2626

27+
@time include("tensors.jl")
28+
2729
@time include("map.jl")
2830

2931
@time include("filter.jl")

test/tensors.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
2+
3+
function contract!(tiJaB_d_temp3, tiJaB_i, Wmnij)
4+
rvir = axes(tiJaB_d_temp3, 4)
5+
nvir = last(rvir)
6+
rocc = axes(tiJaB_d_temp3, 1)
7+
@inbounds @fastmath for b in rvir, a in b:nvir, j in rocc, i in j:last(rocc)
8+
temp = zero(eltype(tiJaB_i))
9+
for n in rocc, m in rocc
10+
temp += tiJaB_i[m,n,a,b]*Wmnij[m,n,i,j]
11+
end
12+
tiJaB_d_temp3[i,j,a,b] = temp
13+
tiJaB_d_temp3[j,i,b,a] = temp
14+
end
15+
end
16+
17+
function contracttest1!(tiJaB_d_temp3, tiJaB_i, Wmnij)
18+
rvir = axes(tiJaB_d_temp3, 4)
19+
nvir = last(rvir)
20+
rocc = axes(tiJaB_d_temp3, 1)
21+
for b in rvir, j in rocc
22+
@avx for a in b:nvir, i in j:last(rocc)
23+
temp = zero(eltype(tiJaB_i))
24+
for n in rocc, m in rocc
25+
temp += tiJaB_i[m,n,a,b]*Wmnij[m,n,i,j]
26+
end
27+
tiJaB_d_temp3[i,j,a,b] = temp
28+
tiJaB_d_temp3[j,i,b,a] = temp
29+
end
30+
end
31+
end
32+
function contracttest2!(tiJaB_d_temp3, tiJaB_i, Wmnij)
33+
rvir = axes(tiJaB_d_temp3, 4)
34+
nvir = last(rvir)
35+
rocc = axes(tiJaB_d_temp3, 1)
36+
for b in rvir, a in b:nvir, j in rocc
37+
@avx for i in j:last(rocc)
38+
temp = zero(eltype(tiJaB_i))
39+
for n in rocc, m in rocc
40+
temp += tiJaB_i[m,n,a,b]*Wmnij[m,n,i,j]
41+
end
42+
tiJaB_d_temp3[i,j,a,b] = temp
43+
tiJaB_d_temp3[j,i,b,a] = temp
44+
end
45+
end
46+
end
47+
48+
@testset "Tensors" begin
49+
LA, LIM = 31, 23;
50+
A = rand(LIM, LIM, LA, LA);
51+
B = rand(LIM, LIM, LIM, LIM);
52+
53+
C1 = Array{Float64}(undef, LIM, LIM, LA, LA);
54+
C2 = similar(C1); C3 = similar(C1);
55+
56+
@time contract!(C1, A, B)
57+
@time contracttest1!(C2, A, B)
58+
@time contracttest2!(C3, A, B)
59+
60+
@test C1 C2
61+
@test C1 C3
62+
end
63+
64+

0 commit comments

Comments
 (0)