Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedDimsArrays"
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.7.16"
version = "0.7.17"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -37,7 +37,7 @@ LinearAlgebra = "1.10"
MapBroadcast = "0.1.6"
Random = "1.10"
SimpleTraits = "0.9.4"
TensorAlgebra = "0.3, 0.4"
TensorAlgebra = "0.4.1"
TupleTools = "1.6.0"
TypeParameterAccessors = "0.4"
julia = "1.10"
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
NamedDimsArrays = "0.7"
TensorAlgebra = "0.3, 0.4"
TensorAlgebra = "0.4"
Literate = "2"
Documenter = "1"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
NamedDimsArrays = "0.7"
TensorAlgebra = "0.3, 0.4"
TensorAlgebra = "0.4"
5 changes: 5 additions & 0 deletions src/abstractnameddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ function replacenameddimsindices(a::AbstractNamedDimsArray, replacements::Pair..
new_nameddimsindices = named.(dename.(old_nameddimsindices), last.(replacements))
return replacenameddimsindices(a, (old_nameddimsindices .=> new_nameddimsindices)...)
end
function replacenameddimsindices(a::AbstractNamedDimsArray, replacements::Dict)
return replacenameddimsindices(a) do name
return get(replacements, name, name)
end
end
function mapnameddimsindices(f, a::AbstractNamedDimsArray)
return setnameddimsindices(a, map(f, nameddimsindices(a)))
end
Expand Down
9 changes: 9 additions & 0 deletions src/nameddimsoperator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ Base.:*(a::AbstractNamedDimsOperator, b::AbstractNamedDimsOperator) = state(a) *
Base.:*(a::AbstractNamedDimsOperator, b::AbstractNamedDimsArray) = state(a) * state(b)
Base.:*(a::AbstractNamedDimsArray, b::AbstractNamedDimsOperator) = state(a) * state(b)

for f in MATRIX_FUNCTIONS
@eval begin
function Base.$f(a::AbstractNamedDimsOperator)
c = codomain(a)
d = domain(a)
return operator($f(state(a), c, d), c .=> d)
end
end
end
struct NamedDimsOperator{T,N,P<:AbstractNamedDimsArray{T,N},D,C} <:
AbstractNamedDimsOperator{T,N}
parent::P
Expand Down
75 changes: 68 additions & 7 deletions src/tensoralgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using TensorAlgebra:
blockedperm,
contract,
contract!,
contractadd!,
eigen,
eigvals,
factorize,
Expand All @@ -25,14 +26,14 @@ using TensorAlgebra:
using TensorAlgebra.BaseExtensions: BaseExtensions
using TupleTools: TupleTools

function TensorAlgebra.contract!(
function TensorAlgebra.contractadd!(
a_dest::AbstractNamedDimsArray,
a1::AbstractNamedDimsArray,
a2::AbstractNamedDimsArray,
α::Number=true,
β::Number=false,
α::Number,
β::Number,
)
contract!(
contractadd!(
dename(a_dest),
nameddimsindices(a_dest),
dename(a1),
Expand All @@ -45,6 +46,12 @@ function TensorAlgebra.contract!(
return a_dest
end

function TensorAlgebra.contract!(
a_dest::AbstractNamedDimsArray, a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray
)
return contractadd!(a_dest, a1, a2, true, false)
end

function TensorAlgebra.contract(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray)
a_dest, nameddimsindices_dest = contract(
dename(a1), nameddimsindices(a1), dename(a2), nameddimsindices(a2)
Expand Down Expand Up @@ -79,10 +86,17 @@ function LinearAlgebra.mul!(
a_dest::AbstractNamedDimsArray,
a1::AbstractNamedDimsArray,
a2::AbstractNamedDimsArray,
α::Number=true,
β::Number=false,
α::Number,
β::Number,
)
contract!(a_dest, a1, a2, α, β)
contractadd!(a_dest, a1, a2, α, β)
return a_dest
end

function LinearAlgebra.mul!(
a_dest::AbstractNamedDimsArray, a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray
)
contract!(a_dest, a1, a2)
return a_dest
end

Expand Down Expand Up @@ -301,3 +315,50 @@ function TensorAlgebra.right_null(a::AbstractNamedDimsArray, dimnames_codomain;
domain = setdiff(nameddimsindices(a), codomain)
return right_null(a, codomain, domain; kwargs...)
end

const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

for f in MATRIX_FUNCTIONS
@eval begin
function Base.$f(
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = to_nameddimsindices(a, dimnames_domain)
fa_unnamed = TensorAlgebra.$f(
dename(a), nameddimsindices(a), codomain, domain; kwargs...
)
return nameddimsarray(fa_unnamed, (codomain..., domain...))
end
end
end
6 changes: 3 additions & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -16,11 +17,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Adapt = "4"
Aqua = "0.8.9"
BlockArrays = "1"
BlockSparseArrays = "0.8, 0.9, 0.10"
BlockSparseArrays = "0.10"
Combinatorics = "1"
GradedArrays = "0.4"
NamedDimsArrays = "0.7"
SafeTestsets = "0.1"
Suppressor = "0.2"
TensorAlgebra = "0.3, 0.4"
TensorAlgebra = "0.4"
Test = "1.10"
19 changes: 18 additions & 1 deletion test/test_tensoralgebra.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LinearAlgebra: factorize, lq, norm, qr, svd
using NamedDimsArrays: dename, nameddimsindices, namedoneto
using NamedDimsArrays: NamedDimsArrays, dename, nameddimsindices, namedoneto
using StableRNGs: StableRNG
using TensorAlgebra:
TensorAlgebra,
contract,
Expand All @@ -14,6 +15,7 @@ using TensorAlgebra:
right_polar,
unmatricize
using Test: @test, @testset, @test_broken

elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "TensorAlgebra (eltype=$(elt))" for elt in elts
@testset "contract" begin
Expand Down Expand Up @@ -45,6 +47,21 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test dename(na_split, ("k", "i", "j", "l")) ≈
reshape(dename(na, ("a", "b")), (dename(k), dename(i), dename(j), dename(l)))
end
@testset "Matrix functions" begin
for f in NamedDimsArrays.MATRIX_FUNCTIONS
f == :cbrt && elt <: Complex && continue
f == :cbrt && VERSION < v"1.11-" && continue
@eval begin
i, j, k, l = namedoneto.((2, 2, 2, 2), ("i", "j", "k", "l"))
rng = StableRNG(123)
a = randn(rng, $elt, (i, j, k, l))
fa = $f(a, (j, l), (k, i))
m = dename(matricize(a, (j, l) => "a", (k, i) => "b"), ("a", "b"))
fm = dename(matricize(fa, (j, l) => "a", (k, i) => "b"), ("a", "b"))
@test fm ≈ $f(m)
end
end
end
@testset "qr/lq" begin
dims = (2, 2, 2, 2)
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))
Expand Down
Loading