Skip to content

Commit 1a6d224

Browse files
authored
LazyNamedDimsArrays (#9)
1 parent 88a6203 commit 1a6d224

File tree

5 files changed

+249
-2
lines changed

5 files changed

+249
-2
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -15,9 +15,11 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
1515
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
1616
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
1717
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
18+
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
19+
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
1820

1921
[compat]
20-
Adapt = "4.3.0"
22+
Adapt = "4.3"
2123
BackendSelection = "0.1.6"
2224
DataGraphs = "0.2.7"
2325
Dictionaries = "0.4.5"
@@ -28,4 +30,6 @@ NamedDimsArrays = "0.8"
2830
NamedGraphs = "0.6.9, 0.7"
2931
SimpleTraits = "0.9.5"
3032
SplitApplyCombine = "1.2.3"
33+
TermInterface = "2"
34+
WrappedUnions = "0.3"
3135
julia = "1.10"

src/ITensorNetworksNext.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ITensorNetworksNext
22

3+
include("lazynameddimsarrays.jl")
34
include("abstracttensornetwork.jl")
45
include("tensornetwork.jl")
56

src/lazynameddimsarrays.jl

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
module LazyNamedDimsArrays
2+
3+
using WrappedUnions: @wrapped, unwrap
4+
using NamedDimsArrays:
5+
NamedDimsArrays,
6+
AbstractNamedDimsArray,
7+
AbstractNamedDimsArrayStyle,
8+
dename,
9+
inds
10+
11+
struct Prod{A}
12+
factors::Vector{A}
13+
end
14+
15+
@wrapped struct LazyNamedDimsArray{
16+
T, A <: AbstractNamedDimsArray{T},
17+
} <: AbstractNamedDimsArray{T, Any}
18+
union::Union{A, Prod{LazyNamedDimsArray{T, A}}}
19+
end
20+
21+
function NamedDimsArrays.inds(a::LazyNamedDimsArray)
22+
if unwrap(a) isa AbstractNamedDimsArray
23+
return inds(unwrap(a))
24+
elseif unwrap(a) isa Prod
25+
return mapreduce(inds, symdiff, unwrap(a).factors)
26+
else
27+
return error("Variant not supported.")
28+
end
29+
end
30+
function NamedDimsArrays.dename(a::LazyNamedDimsArray)
31+
if unwrap(a) isa AbstractNamedDimsArray
32+
return dename(unwrap(a))
33+
elseif unwrap(a) isa Prod
34+
return dename(materialize(a), inds(a))
35+
else
36+
return error("Variant not supported.")
37+
end
38+
end
39+
40+
using Base.Broadcast: materialize
41+
function Base.Broadcast.materialize(a::LazyNamedDimsArray)
42+
if unwrap(a) isa AbstractNamedDimsArray
43+
return unwrap(a)
44+
elseif unwrap(a) isa Prod
45+
return prod(materialize, unwrap(a).factors)
46+
else
47+
return error("Variant not supported.")
48+
end
49+
end
50+
Base.copy(a::LazyNamedDimsArray) = materialize(a)
51+
52+
function Base.:*(a::LazyNamedDimsArray)
53+
if unwrap(a) isa AbstractNamedDimsArray
54+
return LazyNamedDimsArray(Prod([lazy(unwrap(a))]))
55+
elseif unwrap(a) isa Prod
56+
return a
57+
else
58+
return error("Variant not supported.")
59+
end
60+
end
61+
62+
function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
63+
# Nested by default.
64+
return LazyNamedDimsArray(Prod([a1, a2]))
65+
end
66+
function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
67+
return error("Not implemented.")
68+
end
69+
function Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
70+
return error("Not implemented.")
71+
end
72+
function Base.:*(c::Number, a::LazyNamedDimsArray)
73+
return error("Not implemented.")
74+
end
75+
function Base.:*(a::LazyNamedDimsArray, c::Number)
76+
return error("Not implemented.")
77+
end
78+
function Base.:/(a::LazyNamedDimsArray, c::Number)
79+
return error("Not implemented.")
80+
end
81+
function Base.:-(a::LazyNamedDimsArray)
82+
return error("Not implemented.")
83+
end
84+
85+
function LazyNamedDimsArray(a::AbstractNamedDimsArray)
86+
return LazyNamedDimsArray{eltype(a), typeof(a)}(a)
87+
end
88+
function LazyNamedDimsArray(a::Prod{LazyNamedDimsArray{T, A}}) where {T, A}
89+
return LazyNamedDimsArray{T, A}(a)
90+
end
91+
function lazy(a::AbstractNamedDimsArray)
92+
return LazyNamedDimsArray(a)
93+
end
94+
95+
# Broadcasting
96+
struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end
97+
function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray})
98+
return LazyNamedDimsArrayStyle()
99+
end
100+
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...)
101+
return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.")
102+
end
103+
# Linear operations.
104+
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2)
105+
return a1 + a2
106+
end
107+
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2)
108+
return a1 - a2
109+
end
110+
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a)
111+
return c * a
112+
end
113+
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number)
114+
return a * c
115+
end
116+
# Fix ambiguity error.
117+
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number)
118+
return a * b
119+
end
120+
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number)
121+
return a / c
122+
end
123+
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
124+
return -a
125+
end
126+
127+
using TermInterface: TermInterface
128+
# arguments, arity, children, head, iscall, operation
129+
function TermInterface.arguments(a::LazyNamedDimsArray)
130+
if unwrap(a) isa AbstractNamedDimsArray
131+
return error("No arguments.")
132+
elseif unwrap(a) isa Prod
133+
unwrap(a).factors
134+
else
135+
return error("Variant not supported.")
136+
end
137+
end
138+
function TermInterface.children(a::LazyNamedDimsArray)
139+
return TermInterface.arguments(a)
140+
end
141+
function TermInterface.head(a::LazyNamedDimsArray)
142+
return TermInterface.operation(a)
143+
end
144+
function TermInterface.iscall(a::LazyNamedDimsArray)
145+
if unwrap(a) isa AbstractNamedDimsArray
146+
return false
147+
elseif unwrap(a) isa Prod
148+
return true
149+
else
150+
return false
151+
end
152+
end
153+
function TermInterface.isexpr(a::LazyNamedDimsArray)
154+
return TermInterface.iscall(a)
155+
end
156+
function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata)
157+
if head prod
158+
return LazyNamedDimsArray(Prod(args))
159+
else
160+
return error("Only product terms supported right now.")
161+
end
162+
end
163+
function TermInterface.operation(a::LazyNamedDimsArray)
164+
if unwrap(a) isa AbstractNamedDimsArray
165+
return error("No operation.")
166+
elseif unwrap(a) isa Prod
167+
prod
168+
else
169+
return error("Variant not supported.")
170+
end
171+
end
172+
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
173+
if unwrap(a) isa AbstractNamedDimsArray
174+
return error("No arguments.")
175+
elseif unwrap(a) isa Prod
176+
return TermInterface.arguments(a)
177+
else
178+
return error("Variant not supported.")
179+
end
180+
end
181+
182+
end

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
88
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
99
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1010
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
11+
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
1112
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
13+
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
1214

1315
[compat]
1416
Aqua = "0.8.14"
@@ -20,4 +22,6 @@ NamedDimsArrays = "0.8"
2022
NamedGraphs = "0.6.8, 0.7"
2123
SafeTestsets = "0.1"
2224
Suppressor = "0.2.8"
25+
TermInterface = "2"
2326
Test = "1.10"
27+
WrappedUnions = "0.3"

test/test_lazynameddimsarrays.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using Base.Broadcast: materialize
2+
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Prod, lazy
3+
using NamedDimsArrays: NamedDimsArray, inds, nameddims
4+
using TermInterface:
5+
arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments
6+
using Test: @test, @test_throws, @testset
7+
using WrappedUnions: unwrap
8+
9+
@testset "LazyNamedDimsArrays" begin
10+
@testset "Basics" begin
11+
a1 = nameddims(randn(2, 2), (:i, :j))
12+
a2 = nameddims(randn(2, 2), (:j, :k))
13+
a3 = nameddims(randn(2, 2), (:k, :l))
14+
l1, l2, l3 = lazy.((a1, a2, a3))
15+
for li in (l1, l2, l3)
16+
@test li isa LazyNamedDimsArray
17+
@test unwrap(li) isa NamedDimsArray
18+
@test inds(li) == inds(unwrap(li))
19+
@test copy(li) == unwrap(li)
20+
@test materialize(li) == unwrap(li)
21+
end
22+
l = l1 * l2 * l3
23+
@test copy(l) a1 * a2 * a3
24+
@test materialize(l) a1 * a2 * a3
25+
@test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...))
26+
@test unwrap(l) isa Prod
27+
@test unwrap(l).factors == [l1 * l2, l3]
28+
end
29+
30+
@testset "TermInterface" begin
31+
a1 = nameddims(randn(2, 2), (:i, :j))
32+
a2 = nameddims(randn(2, 2), (:j, :k))
33+
a3 = nameddims(randn(2, 2), (:k, :l))
34+
l1, l2, l3 = lazy.((a1, a2, a3))
35+
36+
@test_throws ErrorException arguments(l1)
37+
@test_throws ErrorException arity(l1)
38+
@test_throws ErrorException children(l1)
39+
@test_throws ErrorException head(l1)
40+
@test !iscall(l1)
41+
@test !isexpr(l1)
42+
@test_throws ErrorException operation(l1)
43+
@test_throws ErrorException sorted_arguments(l1)
44+
45+
l = l1 * l2 * l3
46+
@test arguments(l) == [l1 * l2, l3]
47+
@test arity(l) == 2
48+
@test children(l) == [l1 * l2, l3]
49+
@test head(l) prod
50+
@test iscall(l)
51+
@test isexpr(l)
52+
@test l == maketerm(LazyNamedDimsArray, prod, [l1 * l2, l3], nothing)
53+
@test operation(l) prod
54+
@test sorted_arguments(l) == [l1 * l2, l3]
55+
end
56+
end

0 commit comments

Comments
 (0)