1
1
@eval module $ (gensym ())
2
2
using Compat: Returns
3
3
using Test: @test , @testset , @test_broken
4
- using BlockArrays: Block, blocksize
4
+ using BlockArrays: Block, blockedrange, blocksize
5
5
using NDTensors. BlockSparseArrays: BlockSparseArray, block_nstored
6
6
using NDTensors. GradedAxes:
7
7
GradedAxes, GradedUnitRange, UnitRangeDual, blocklabels, dual, gradedrange
@@ -73,6 +73,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
73
73
# be the real test.
74
74
for ax in axes (m)
75
75
@test ax isa GradedUnitRange
76
+ # TODO : Current `fusedims` doesn't merge
77
+ # common sectors, need to fix.
76
78
@test_broken blocklabels (ax) == [U1 (0 ), U1 (1 ), U1 (2 )]
77
79
@test blocklabels (ax) == [U1 (0 ), U1 (1 ), U1 (1 ), U1 (2 )]
78
80
end
@@ -94,8 +96,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
94
96
@testset " dual axes" begin
95
97
r = gradedrange ([U1 (0 ) => 2 , U1 (1 ) => 2 ])
96
98
a = BlockSparseArray {elt} (dual (r), r)
97
- a[Block (1 , 1 )] = randn (elt, size (a[Block (1 , 1 )]))
98
- a[Block (2 , 2 )] = randn (elt, size (a[Block (2 , 2 )]))
99
+ @views for b in [Block (1 , 1 ), Block (2 , 2 )]
100
+ a[b] = randn (elt, size (a[b]))
101
+ end
102
+ # TODO : Define and use `isdual` here.
103
+ @test axes (a, 1 ) isa UnitRangeDual
104
+ @test axes (a, 2 ) isa GradedUnitRange
105
+ @test ! (axes (a, 2 ) isa UnitRangeDual)
99
106
a_dense = Array (a)
100
107
@test eachindex (a) == CartesianIndices (size (a))
101
108
for I in eachindex (a)
@@ -104,8 +111,50 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
104
111
@test axes (a' ) == dual .(reverse (axes (a)))
105
112
# TODO : Define and use `isdual` here.
106
113
@test axes (a' , 1 ) isa UnitRangeDual
114
+ @test axes (a' , 2 ) isa GradedUnitRange
107
115
@test ! (axes (a' , 2 ) isa UnitRangeDual)
108
116
@test isnothing (show (devnull , MIME (" text/plain" ), a))
117
+
118
+ # Check preserving dual in tensor algebra.
119
+ for b in (a + a, 2 * a, 3 * a - a)
120
+ @test Array (b) ≈ 2 * Array (a)
121
+ # TODO : Define and use `isdual` here.
122
+ @test axes (b, 1 ) isa UnitRangeDual
123
+ @test axes (b, 2 ) isa GradedUnitRange
124
+ @test ! (axes (b, 2 ) isa UnitRangeDual)
125
+ end
126
+
127
+ @test isnothing (show (devnull , MIME (" text/plain" ), @view (a[Block (1 , 1 )])))
128
+ @test @view (a[Block (1 , 1 )]) == a[Block (1 , 1 )]
129
+
130
+ # Test case when all axes are dual.
131
+ for r in (gradedrange ([U1 (0 ) => 2 , U1 (1 ) => 2 ]), blockedrange ([2 , 2 ]))
132
+ a = BlockSparseArray {elt} (dual (r), dual (r))
133
+ @views for i in [Block (1 , 1 ), Block (2 , 2 )]
134
+ a[i] = randn (elt, size (a[i]))
135
+ end
136
+ b = 2 * a
137
+ @test block_nstored (b) == 2
138
+ @test Array (b) == 2 * Array (a)
139
+ for ax in axes (b)
140
+ @test ax isa UnitRangeDual
141
+ end
142
+ end
143
+
144
+ # Test case when all axes are dual
145
+ # from taking the adjoint.
146
+ for r in (gradedrange ([U1 (0 ) => 2 , U1 (1 ) => 2 ]), blockedrange ([2 , 2 ]))
147
+ a = BlockSparseArray {elt} (r, r)
148
+ @views for i in [Block (1 , 1 ), Block (2 , 2 )]
149
+ a[i] = randn (elt, size (a[i]))
150
+ end
151
+ b = 2 * a'
152
+ @test block_nstored (b) == 2
153
+ @test Array (b) == 2 * Array (a)'
154
+ for ax in axes (b)
155
+ @test ax isa UnitRangeDual
156
+ end
157
+ end
109
158
end
110
159
@testset " Matrix multiplication" begin
111
160
r = gradedrange ([U1 (0 ) => 2 , U1 (1 ) => 3 ])
0 commit comments