Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworksNext"
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.6"
version = "0.1.7"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
162 changes: 87 additions & 75 deletions src/lazynameddimsarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,119 @@ using NamedDimsArrays:
AbstractNamedDimsArrayStyle,
dename,
inds
using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments

struct Prod{A}
factors::Vector{A}
end
struct Mul{A}
arguments::Vector{A}
end
TermInterface.arguments(m::Mul) = getfield(m, :arguments)
TermInterface.children(m::Mul) = arguments(m)
TermInterface.head(m::Mul) = operation(m)
TermInterface.iscall(m::Mul) = true
TermInterface.isexpr(m::Mul) = iscall(m)
TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args)
TermInterface.operation(m::Mul) = *
TermInterface.sorted_arguments(m::Mul) = arguments(m)
TermInterface.sorted_children(m::Mul) = sorted_arguments(a)

@wrapped struct LazyNamedDimsArray{
T, A <: AbstractNamedDimsArray{T},
} <: AbstractNamedDimsArray{T, Any}
union::Union{A, Prod{LazyNamedDimsArray{T, A}}}
union::Union{A, Mul{LazyNamedDimsArray{T, A}}}
end

function NamedDimsArrays.inds(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return inds(unwrap(a))
elseif unwrap(a) isa Prod
return mapreduce(inds, symdiff, unwrap(a).factors)
u = unwrap(a)
if u isa AbstractNamedDimsArray
return inds(u)
elseif u isa Mul
return mapreduce(inds, symdiff, arguments(u))
else
return error("Variant not supported.")
end
end
function NamedDimsArrays.dename(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return dename(unwrap(a))
elseif unwrap(a) isa Prod
u = unwrap(a)
if u isa AbstractNamedDimsArray
return dename(u)
elseif u isa Mul
return dename(materialize(a), inds(a))
else
return error("Variant not supported.")
end
end

function TermInterface.arguments(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
return error("No arguments.")
elseif u isa Mul
return arguments(u)
else
return error("Variant not supported.")
end
end
function TermInterface.children(a::LazyNamedDimsArray)
return arguments(a)
end
function TermInterface.head(a::LazyNamedDimsArray)
return operation(a)
end
function TermInterface.iscall(a::LazyNamedDimsArray)
return iscall(unwrap(a))
end
function TermInterface.isexpr(a::LazyNamedDimsArray)
return iscall(a)
end
function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata)
if head ≡ *
return LazyNamedDimsArray(maketerm(Mul, head, args, metadata))
else
return error("Only product terms supported right now.")
end
end
function TermInterface.operation(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
return error("No operation.")
elseif u isa Mul
return operation(u)
else
return error("Variant not supported.")
end
end
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
return error("No arguments.")
elseif u isa Mul
return sorted_arguments(u)
else
return error("Variant not supported.")
end
end
function TermInterface.sorted_children(a::LazyNamedDimsArray)
return sorted_arguments(a)
end

using Base.Broadcast: materialize
function Base.Broadcast.materialize(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return unwrap(a)
elseif unwrap(a) isa Prod
return prod(materialize, unwrap(a).factors)
u = unwrap(a)
if u isa AbstractNamedDimsArray
return u
elseif u isa Mul
return mapfoldl(materialize, operation(u), arguments(u))
else
return error("Variant not supported.")
end
end
Base.copy(a::LazyNamedDimsArray) = materialize(a)

function Base.:*(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return LazyNamedDimsArray(Prod([lazy(unwrap(a))]))
elseif unwrap(a) isa Prod
u = unwrap(a)
if u isa AbstractNamedDimsArray
return LazyNamedDimsArray(Mul([lazy(u)]))
elseif u isa Mul
return a
else
return error("Variant not supported.")
Expand All @@ -61,7 +128,7 @@ end

function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
# Nested by default.
return LazyNamedDimsArray(Prod([a1, a2]))
return LazyNamedDimsArray(Mul([a1, a2]))
end
function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
return error("Not implemented.")
Expand All @@ -85,7 +152,7 @@ end
function LazyNamedDimsArray(a::AbstractNamedDimsArray)
return LazyNamedDimsArray{eltype(a), typeof(a)}(a)
end
function LazyNamedDimsArray(a::Prod{LazyNamedDimsArray{T, A}}) where {T, A}
function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A}
return LazyNamedDimsArray{T, A}(a)
end
function lazy(a::AbstractNamedDimsArray)
Expand Down Expand Up @@ -124,59 +191,4 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
return -a
end

using TermInterface: TermInterface
# arguments, arity, children, head, iscall, operation
function TermInterface.arguments(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return error("No arguments.")
elseif unwrap(a) isa Prod
unwrap(a).factors
else
return error("Variant not supported.")
end
end
function TermInterface.children(a::LazyNamedDimsArray)
return TermInterface.arguments(a)
end
function TermInterface.head(a::LazyNamedDimsArray)
return TermInterface.operation(a)
end
function TermInterface.iscall(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return false
elseif unwrap(a) isa Prod
return true
else
return false
end
end
function TermInterface.isexpr(a::LazyNamedDimsArray)
return TermInterface.iscall(a)
end
function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata)
if head ≡ prod
return LazyNamedDimsArray(Prod(args))
else
return error("Only product terms supported right now.")
end
end
function TermInterface.operation(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return error("No operation.")
elseif unwrap(a) isa Prod
prod
else
return error("Variant not supported.")
end
end
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return error("No arguments.")
elseif unwrap(a) isa Prod
return TermInterface.arguments(a)
else
return error("Variant not supported.")
end
end

end
28 changes: 21 additions & 7 deletions test/test_lazynameddimsarrays.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
using Base.Broadcast: materialize
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Prod, lazy
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy
using NamedDimsArrays: NamedDimsArray, inds, nameddims
using TermInterface:
arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments
arguments,
arity,
children,
head,
iscall,
isexpr,
maketerm,
operation,
sorted_arguments,
sorted_children
using Test: @test, @test_throws, @testset
using WrappedUnions: unwrap

Expand All @@ -23,8 +32,11 @@ using WrappedUnions: unwrap
@test copy(l) ≈ a1 * a2 * a3
@test materialize(l) ≈ a1 * a2 * a3
@test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...))
@test unwrap(l) isa Prod
@test unwrap(l).factors == [l1 * l2, l3]
@test unwrap(l) isa Mul
@test unwrap(l).arguments == [l1 * l2, l3]
# TermInterface.jl
@test operation(unwrap(l)) ≡ *
@test arguments(unwrap(l)) == [l1 * l2, l3]
end

@testset "TermInterface" begin
Expand All @@ -41,16 +53,18 @@ using WrappedUnions: unwrap
@test !isexpr(l1)
@test_throws ErrorException operation(l1)
@test_throws ErrorException sorted_arguments(l1)
@test_throws ErrorException sorted_children(l1)

l = l1 * l2 * l3
@test arguments(l) == [l1 * l2, l3]
@test arity(l) == 2
@test children(l) == [l1 * l2, l3]
@test head(l) ≡ prod
@test head(l) ≡ *
@test iscall(l)
@test isexpr(l)
@test l == maketerm(LazyNamedDimsArray, prod, [l1 * l2, l3], nothing)
@test operation(l) ≡ prod
@test l == maketerm(LazyNamedDimsArray, *, [l1 * l2, l3], nothing)
@test operation(l) ≡ *
@test sorted_arguments(l) == [l1 * l2, l3]
@test sorted_children(l) == [l1 * l2, l3]
end
end
Loading