Skip to content

Commit 8a6158e

Browse files
authored
Change Prod to Mul (#10)
1 parent 1a6d224 commit 8a6158e

File tree

3 files changed

+109
-83
lines changed

3 files changed

+109
-83
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.6"
4+
version = "0.1.7"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/lazynameddimsarrays.jl

Lines changed: 87 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -7,52 +7,119 @@ using NamedDimsArrays:
77
AbstractNamedDimsArrayStyle,
88
dename,
99
inds
10+
using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments
1011

11-
struct Prod{A}
12-
factors::Vector{A}
13-
end
12+
struct Mul{A}
13+
arguments::Vector{A}
14+
end
15+
TermInterface.arguments(m::Mul) = getfield(m, :arguments)
16+
TermInterface.children(m::Mul) = arguments(m)
17+
TermInterface.head(m::Mul) = operation(m)
18+
TermInterface.iscall(m::Mul) = true
19+
TermInterface.isexpr(m::Mul) = iscall(m)
20+
TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args)
21+
TermInterface.operation(m::Mul) = *
22+
TermInterface.sorted_arguments(m::Mul) = arguments(m)
23+
TermInterface.sorted_children(m::Mul) = sorted_arguments(a)
1424

1525
@wrapped struct LazyNamedDimsArray{
1626
T, A <: AbstractNamedDimsArray{T},
1727
} <: AbstractNamedDimsArray{T, Any}
18-
union::Union{A, Prod{LazyNamedDimsArray{T, A}}}
28+
union::Union{A, Mul{LazyNamedDimsArray{T, A}}}
1929
end
2030

2131
function NamedDimsArrays.inds(a::LazyNamedDimsArray)
22-
if unwrap(a) isa AbstractNamedDimsArray
23-
return inds(unwrap(a))
24-
elseif unwrap(a) isa Prod
25-
return mapreduce(inds, symdiff, unwrap(a).factors)
32+
u = unwrap(a)
33+
if u isa AbstractNamedDimsArray
34+
return inds(u)
35+
elseif u isa Mul
36+
return mapreduce(inds, symdiff, arguments(u))
2637
else
2738
return error("Variant not supported.")
2839
end
2940
end
3041
function NamedDimsArrays.dename(a::LazyNamedDimsArray)
31-
if unwrap(a) isa AbstractNamedDimsArray
32-
return dename(unwrap(a))
33-
elseif unwrap(a) isa Prod
42+
u = unwrap(a)
43+
if u isa AbstractNamedDimsArray
44+
return dename(u)
45+
elseif u isa Mul
3446
return dename(materialize(a), inds(a))
3547
else
3648
return error("Variant not supported.")
3749
end
3850
end
3951

52+
function TermInterface.arguments(a::LazyNamedDimsArray)
53+
u = unwrap(a)
54+
if u isa AbstractNamedDimsArray
55+
return error("No arguments.")
56+
elseif u isa Mul
57+
return arguments(u)
58+
else
59+
return error("Variant not supported.")
60+
end
61+
end
62+
function TermInterface.children(a::LazyNamedDimsArray)
63+
return arguments(a)
64+
end
65+
function TermInterface.head(a::LazyNamedDimsArray)
66+
return operation(a)
67+
end
68+
function TermInterface.iscall(a::LazyNamedDimsArray)
69+
return iscall(unwrap(a))
70+
end
71+
function TermInterface.isexpr(a::LazyNamedDimsArray)
72+
return iscall(a)
73+
end
74+
function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata)
75+
if head *
76+
return LazyNamedDimsArray(maketerm(Mul, head, args, metadata))
77+
else
78+
return error("Only product terms supported right now.")
79+
end
80+
end
81+
function TermInterface.operation(a::LazyNamedDimsArray)
82+
u = unwrap(a)
83+
if u isa AbstractNamedDimsArray
84+
return error("No operation.")
85+
elseif u isa Mul
86+
return operation(u)
87+
else
88+
return error("Variant not supported.")
89+
end
90+
end
91+
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
92+
u = unwrap(a)
93+
if u isa AbstractNamedDimsArray
94+
return error("No arguments.")
95+
elseif u isa Mul
96+
return sorted_arguments(u)
97+
else
98+
return error("Variant not supported.")
99+
end
100+
end
101+
function TermInterface.sorted_children(a::LazyNamedDimsArray)
102+
return sorted_arguments(a)
103+
end
104+
40105
using Base.Broadcast: materialize
41106
function Base.Broadcast.materialize(a::LazyNamedDimsArray)
42-
if unwrap(a) isa AbstractNamedDimsArray
43-
return unwrap(a)
44-
elseif unwrap(a) isa Prod
45-
return prod(materialize, unwrap(a).factors)
107+
u = unwrap(a)
108+
if u isa AbstractNamedDimsArray
109+
return u
110+
elseif u isa Mul
111+
return mapfoldl(materialize, operation(u), arguments(u))
46112
else
47113
return error("Variant not supported.")
48114
end
49115
end
50116
Base.copy(a::LazyNamedDimsArray) = materialize(a)
51117

52118
function Base.:*(a::LazyNamedDimsArray)
53-
if unwrap(a) isa AbstractNamedDimsArray
54-
return LazyNamedDimsArray(Prod([lazy(unwrap(a))]))
55-
elseif unwrap(a) isa Prod
119+
u = unwrap(a)
120+
if u isa AbstractNamedDimsArray
121+
return LazyNamedDimsArray(Mul([lazy(u)]))
122+
elseif u isa Mul
56123
return a
57124
else
58125
return error("Variant not supported.")
@@ -61,7 +128,7 @@ end
61128

62129
function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
63130
# Nested by default.
64-
return LazyNamedDimsArray(Prod([a1, a2]))
131+
return LazyNamedDimsArray(Mul([a1, a2]))
65132
end
66133
function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
67134
return error("Not implemented.")
@@ -85,7 +152,7 @@ end
85152
function LazyNamedDimsArray(a::AbstractNamedDimsArray)
86153
return LazyNamedDimsArray{eltype(a), typeof(a)}(a)
87154
end
88-
function LazyNamedDimsArray(a::Prod{LazyNamedDimsArray{T, A}}) where {T, A}
155+
function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A}
89156
return LazyNamedDimsArray{T, A}(a)
90157
end
91158
function lazy(a::AbstractNamedDimsArray)
@@ -124,59 +191,4 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
124191
return -a
125192
end
126193

127-
using TermInterface: TermInterface
128-
# arguments, arity, children, head, iscall, operation
129-
function TermInterface.arguments(a::LazyNamedDimsArray)
130-
if unwrap(a) isa AbstractNamedDimsArray
131-
return error("No arguments.")
132-
elseif unwrap(a) isa Prod
133-
unwrap(a).factors
134-
else
135-
return error("Variant not supported.")
136-
end
137-
end
138-
function TermInterface.children(a::LazyNamedDimsArray)
139-
return TermInterface.arguments(a)
140-
end
141-
function TermInterface.head(a::LazyNamedDimsArray)
142-
return TermInterface.operation(a)
143-
end
144-
function TermInterface.iscall(a::LazyNamedDimsArray)
145-
if unwrap(a) isa AbstractNamedDimsArray
146-
return false
147-
elseif unwrap(a) isa Prod
148-
return true
149-
else
150-
return false
151-
end
152-
end
153-
function TermInterface.isexpr(a::LazyNamedDimsArray)
154-
return TermInterface.iscall(a)
155-
end
156-
function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata)
157-
if head prod
158-
return LazyNamedDimsArray(Prod(args))
159-
else
160-
return error("Only product terms supported right now.")
161-
end
162-
end
163-
function TermInterface.operation(a::LazyNamedDimsArray)
164-
if unwrap(a) isa AbstractNamedDimsArray
165-
return error("No operation.")
166-
elseif unwrap(a) isa Prod
167-
prod
168-
else
169-
return error("Variant not supported.")
170-
end
171-
end
172-
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
173-
if unwrap(a) isa AbstractNamedDimsArray
174-
return error("No arguments.")
175-
elseif unwrap(a) isa Prod
176-
return TermInterface.arguments(a)
177-
else
178-
return error("Variant not supported.")
179-
end
180-
end
181-
182194
end

test/test_lazynameddimsarrays.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
using Base.Broadcast: materialize
2-
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Prod, lazy
2+
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy
33
using NamedDimsArrays: NamedDimsArray, inds, nameddims
44
using TermInterface:
5-
arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments
5+
arguments,
6+
arity,
7+
children,
8+
head,
9+
iscall,
10+
isexpr,
11+
maketerm,
12+
operation,
13+
sorted_arguments,
14+
sorted_children
615
using Test: @test, @test_throws, @testset
716
using WrappedUnions: unwrap
817

@@ -23,8 +32,11 @@ using WrappedUnions: unwrap
2332
@test copy(l) a1 * a2 * a3
2433
@test materialize(l) a1 * a2 * a3
2534
@test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...))
26-
@test unwrap(l) isa Prod
27-
@test unwrap(l).factors == [l1 * l2, l3]
35+
@test unwrap(l) isa Mul
36+
@test unwrap(l).arguments == [l1 * l2, l3]
37+
# TermInterface.jl
38+
@test operation(unwrap(l)) *
39+
@test arguments(unwrap(l)) == [l1 * l2, l3]
2840
end
2941

3042
@testset "TermInterface" begin
@@ -41,16 +53,18 @@ using WrappedUnions: unwrap
4153
@test !isexpr(l1)
4254
@test_throws ErrorException operation(l1)
4355
@test_throws ErrorException sorted_arguments(l1)
56+
@test_throws ErrorException sorted_children(l1)
4457

4558
l = l1 * l2 * l3
4659
@test arguments(l) == [l1 * l2, l3]
4760
@test arity(l) == 2
4861
@test children(l) == [l1 * l2, l3]
49-
@test head(l) prod
62+
@test head(l) *
5063
@test iscall(l)
5164
@test isexpr(l)
52-
@test l == maketerm(LazyNamedDimsArray, prod, [l1 * l2, l3], nothing)
53-
@test operation(l) prod
65+
@test l == maketerm(LazyNamedDimsArray, *, [l1 * l2, l3], nothing)
66+
@test operation(l) *
5467
@test sorted_arguments(l) == [l1 * l2, l3]
68+
@test sorted_children(l) == [l1 * l2, l3]
5569
end
5670
end

0 commit comments

Comments
 (0)