From f0b651f3eda5394032610e526dffac0c172a9d1b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 6 Oct 2025 22:46:44 -0400 Subject: [PATCH 1/4] LazyNamedDimsArrays --- Project.toml | 6 +- src/ITensorNetworksNext.jl | 1 + src/lazynameddimsarrays.jl | 182 +++++++++++++++++++++++++++++++ test/Project.toml | 4 + test/test_lazynameddimsarrays.jl | 55 ++++++++++ 5 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 src/lazynameddimsarrays.jl create mode 100644 test/test_lazynameddimsarrays.jl diff --git a/Project.toml b/Project.toml index fc1ba2d..b41b50a 100644 --- a/Project.toml +++ b/Project.toml @@ -15,9 +15,11 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" +WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] -Adapt = "4.3.0" +Adapt = "4.3" BackendSelection = "0.1.6" DataGraphs = "0.2.7" Dictionaries = "0.4.5" @@ -28,4 +30,6 @@ NamedDimsArrays = "0.8" NamedGraphs = "0.6.9, 0.7" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" +TermInterface = "2" +WrappedUnions = "0.3" julia = "1.10" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 89daa37..9aa9579 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -2,5 +2,6 @@ module ITensorNetworksNext include("abstracttensornetwork.jl") include("tensornetwork.jl") +include("lazynameddimsarrays.jl") end diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl new file mode 100644 index 0000000..04eca26 --- /dev/null +++ b/src/lazynameddimsarrays.jl @@ -0,0 +1,182 @@ +module LazyNamedDimsArrays + +using WrappedUnions: @wrapped, unwrap +using NamedDimsArrays: + NamedDimsArrays, + AbstractNamedDimsArray, + AbstractNamedDimsArrayStyle, + dename, + inds + +struct Prod{A} + factors::Vector{A} +end + +@wrapped struct LazyNamedDimsArray{ + T, A <: AbstractNamedDimsArray{T}, + } <: AbstractNamedDimsArray{T, Any} + union::Union{A, Prod{LazyNamedDimsArray{T, A}}} +end + +function NamedDimsArrays.inds(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return inds(unwrap(a)) + elseif unwrap(a) isa Prod + return mapreduce(inds, symdiff, unwrap(a).factors) + else + return error("Variant not supported.") + end +end +function NamedDimsArrays.dename(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return dename(unwrap(a)) + elseif unwrap(a) isa Prod + return dename(materialize(a), inds(a)) + else + return error("Variant not supported.") + end +end + +using Base.Broadcast: materialize +function Base.Broadcast.materialize(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return unwrap(a) + elseif unwrap(a) isa Prod + return prod(materialize, unwrap(a).factors) + else + return error("Variant not supported.") + end +end +Base.copy(a::LazyNamedDimsArray) = materialize(a) + +function Base.:*(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return LazyNamedDimsArray(Prod([lazy(unwrap(a))])) + elseif unwrap(a) isa Prod + return a + else + return error("Variant not supported.") + end +end + +function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) + # Nested by default. + return LazyNamedDimsArray(Prod([a1, a2])) +end +function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) + return error("Not implemented.") +end +function Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) + return error("Not implemented.") +end +function Base.:*(c::Number, a::LazyNamedDimsArray) + return error("Not implemented.") +end +function Base.:*(a::LazyNamedDimsArray, c::Number) + return error("Not implemented.") +end +function Base.:/(a::LazyNamedDimsArray, c::Number) + return error("Not implemented.") +end +function Base.:-(a::LazyNamedDimsArray) + return error("Not implemented.") +end + +function LazyNamedDimsArray(a::AbstractNamedDimsArray) + return LazyNamedDimsArray{eltype(a), typeof(a)}(a) +end +function LazyNamedDimsArray(a::Prod{LazyNamedDimsArray{T, A}}) where {T, A} + return LazyNamedDimsArray{T, A}(a) +end +function lazy(a::AbstractNamedDimsArray) + return LazyNamedDimsArray(a) +end + +# Broadcasting +struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end +function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray}) + return LazyNamedDimsArrayStyle() +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...) + return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.") +end +# Linear operations. +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2) + return a1 + a2 +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2) + return a1 - a2 +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a) + return c * a +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number) + return a * c +end +# Fix ambiguity error. +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number) + return a * b +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number) + return a / c +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) + return -a +end + +using TermInterface: TermInterface +# arguments, arity, children, head, iscall, operation +function TermInterface.arguments(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return error("No arguments.") + elseif unwrap(a) isa Prod + unwrap(a).factors + else + return error("Variant not supported.") + end +end +function TermInterface.children(a::LazyNamedDimsArray) + return TermInterface.arguments(a) +end +function TermInterface.head(a::LazyNamedDimsArray) + return TermInterface.operation(a) +end +function TermInterface.iscall(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return false + elseif unwrap(a) isa Prod + return true + else + return false + end +end +function TermInterface.isexpr(a::LazyNamedDimsArray) + return TermInterface.iscall(a) +end +function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata) + if head ≡ prod + return LazyNamedDimsArray(Prod(args)) + else + return error("Only product terms supported right now.") + end +end +function TermInterface.operation(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return error("No operation.") + elseif unwrap(a) isa Prod + prod + else + return error("Variant not supported.") + end +end +function TermInterface.sorted_arguments(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return error("No arguments.") + elseif unwrap(a) isa Prod + return TermInterface.arguments(a) + else + return error("Variant not supported.") + end +end + +end diff --git a/test/Project.toml b/test/Project.toml index b22f9d1..9646508 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,7 +8,9 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] Aqua = "0.8.14" @@ -20,4 +22,6 @@ NamedDimsArrays = "0.8" NamedGraphs = "0.6.8, 0.7" SafeTestsets = "0.1" Suppressor = "0.2.8" +TermInterface = "2" Test = "1.10" +WrappedUnions = "0.3" diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl new file mode 100644 index 0000000..10ca751 --- /dev/null +++ b/test/test_lazynameddimsarrays.jl @@ -0,0 +1,55 @@ +using Base.Broadcast: materialize +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Prod, lazy +using NamedDimsArrays: NamedDimsArray, inds, nameddims +using TermInterface: arguments, arity, children, head, iscall, isexpr, maketerm, operation +using Test: @test, @test_throws, @testset +using WrappedUnions: unwrap + +@testset "LazyNamedDimsArrays" begin + @testset "Basics" begin + a1 = nameddims(randn(2, 2), (:i, :j)) + a2 = nameddims(randn(2, 2), (:j, :k)) + a3 = nameddims(randn(2, 2), (:k, :l)) + l1, l2, l3 = lazy.((a1, a2, a3)) + for li in (l1, l2, l3) + @test li isa LazyNamedDimsArray + @test unwrap(li) isa NamedDimsArray + @test inds(li) == inds(unwrap(li)) + @test copy(li) == unwrap(li) + @test materialize(li) == unwrap(li) + end + l = l1 * l2 * l3 + @test copy(l) ≈ a1 * a2 * a3 + @test materialize(l) ≈ a1 * a2 * a3 + @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) + @test unwrap(l) isa Prod + @test unwrap(l).factors == [l1 * l2, l3] + end + + @testset "TermInterface" begin + a1 = nameddims(randn(2, 2), (:i, :j)) + a2 = nameddims(randn(2, 2), (:j, :k)) + a3 = nameddims(randn(2, 2), (:k, :l)) + l1, l2, l3 = lazy.((a1, a2, a3)) + + @test_throws ErrorException arguments(l1) + @test_throws ErrorException arity(l1) + @test_throws ErrorException children(l1) + @test_throws ErrorException head(l1) + @test !iscall(l1) + @test !isexpr(l1) + @test_throws ErrorException operation(l1) + @test_throws ErrorException sorted_arguments(l1) + + l = l1 * l2 * l3 + @test arguments(l) == [l1 * l2, l3] + @test arity(l) == 2 + @test children(l) == [l1 * l2, l3] + @test head(l) ≡ prod + @test iscall(l) + @test isexpr(l) + @test l == maketerm(LazyNamedDimsArray, prod, [l1 * l2, l3], nothing) + @test operation(l) ≡ prod + @test sorted_arguments(l) == [l1 * l2, l3] + end +end From 9c211bf5ac363f1ebbdc3cccd85cc46d2f3fe5f1 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 6 Oct 2025 22:47:11 -0400 Subject: [PATCH 2/4] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b41b50a..b527a53 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.5" +version = "0.1.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From bf831e9e9af3b858abda27f6e14d8024e8958b73 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 6 Oct 2025 22:50:54 -0400 Subject: [PATCH 3/4] Fix tests --- test/test_lazynameddimsarrays.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index 10ca751..958c191 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,7 +1,8 @@ using Base.Broadcast: materialize using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Prod, lazy using NamedDimsArrays: NamedDimsArray, inds, nameddims -using TermInterface: arguments, arity, children, head, iscall, isexpr, maketerm, operation +using TermInterface: + arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments using Test: @test, @test_throws, @testset using WrappedUnions: unwrap From 52926c05b6df540e7c1062317a4275d72bf89caf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 6 Oct 2025 22:55:41 -0400 Subject: [PATCH 4/4] Reorder includes --- src/ITensorNetworksNext.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 9aa9579..35c9e59 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,7 +1,7 @@ module ITensorNetworksNext +include("lazynameddimsarrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") -include("lazynameddimsarrays.jl") end