Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedDimsArrays"
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.7.16"
version = "0.7.17"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -37,7 +37,7 @@ LinearAlgebra = "1.10"
MapBroadcast = "0.1.6"
Random = "1.10"
SimpleTraits = "0.9.4"
TensorAlgebra = "0.3, 0.4"
TensorAlgebra = "0.4.1"
TupleTools = "1.6.0"
TypeParameterAccessors = "0.4"
julia = "1.10"
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
NamedDimsArrays = "0.7"
TensorAlgebra = "0.3, 0.4"
TensorAlgebra = "0.4"
Literate = "2"
Documenter = "1"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
NamedDimsArrays = "0.7"
TensorAlgebra = "0.3, 0.4"
TensorAlgebra = "0.4"
5 changes: 5 additions & 0 deletions src/abstractnameddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ function replacenameddimsindices(a::AbstractNamedDimsArray, replacements::Pair..
new_nameddimsindices = named.(dename.(old_nameddimsindices), last.(replacements))
return replacenameddimsindices(a, (old_nameddimsindices .=> new_nameddimsindices)...)
end
function replacenameddimsindices(a::AbstractNamedDimsArray, replacements::Dict)
return replacenameddimsindices(a) do name
return get(replacements, name, name)
end
end
function mapnameddimsindices(f, a::AbstractNamedDimsArray)
return setnameddimsindices(a, map(f, nameddimsindices(a)))
end
Expand Down
9 changes: 9 additions & 0 deletions src/nameddimsoperator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ Base.:*(a::AbstractNamedDimsOperator, b::AbstractNamedDimsOperator) = state(a) *
Base.:*(a::AbstractNamedDimsOperator, b::AbstractNamedDimsArray) = state(a) * state(b)
Base.:*(a::AbstractNamedDimsArray, b::AbstractNamedDimsOperator) = state(a) * state(b)

for f in MATRIX_FUNCTIONS
@eval begin
function Base.$f(a::AbstractNamedDimsOperator)
c = codomain(a)
d = domain(a)
return operator($f(state(a), c, d), c .=> d)
end
end
end
struct NamedDimsOperator{T,N,P<:AbstractNamedDimsArray{T,N},D,C} <:
AbstractNamedDimsOperator{T,N}
parent::P
Expand Down
75 changes: 68 additions & 7 deletions src/tensoralgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using TensorAlgebra:
blockedperm,
contract,
contract!,
contractadd!,
eigen,
eigvals,
factorize,
Expand All @@ -25,14 +26,14 @@ using TensorAlgebra:
using TensorAlgebra.BaseExtensions: BaseExtensions
using TupleTools: TupleTools

function TensorAlgebra.contract!(
function TensorAlgebra.contractadd!(
a_dest::AbstractNamedDimsArray,
a1::AbstractNamedDimsArray,
a2::AbstractNamedDimsArray,
α::Number=true,
β::Number=false,
α::Number,
β::Number,
)
contract!(
contractadd!(
dename(a_dest),
nameddimsindices(a_dest),
dename(a1),
Expand All @@ -45,6 +46,12 @@ function TensorAlgebra.contract!(
return a_dest
end

function TensorAlgebra.contract!(
a_dest::AbstractNamedDimsArray, a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray
)
return contractadd!(a_dest, a1, a2, true, false)
end

function TensorAlgebra.contract(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray)
a_dest, nameddimsindices_dest = contract(
dename(a1), nameddimsindices(a1), dename(a2), nameddimsindices(a2)
Expand Down Expand Up @@ -79,10 +86,17 @@ function LinearAlgebra.mul!(
a_dest::AbstractNamedDimsArray,
a1::AbstractNamedDimsArray,
a2::AbstractNamedDimsArray,
α::Number=true,
β::Number=false,
α::Number,
β::Number,
)
contract!(a_dest, a1, a2, α, β)
contractadd!(a_dest, a1, a2, α, β)
return a_dest
end

function LinearAlgebra.mul!(
a_dest::AbstractNamedDimsArray, a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray
)
contract!(a_dest, a1, a2)
return a_dest
end

Expand Down Expand Up @@ -301,3 +315,50 @@ function TensorAlgebra.right_null(a::AbstractNamedDimsArray, dimnames_codomain;
domain = setdiff(nameddimsindices(a), codomain)
return right_null(a, codomain, domain; kwargs...)
end

const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

for f in MATRIX_FUNCTIONS
@eval begin
function Base.$f(
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = to_nameddimsindices(a, dimnames_domain)
fa_unnamed = TensorAlgebra.$f(
dename(a), nameddimsindices(a), codomain, domain; kwargs...
)
return nameddimsarray(fa_unnamed, (codomain..., domain...))
end
end
end
5 changes: 3 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -16,11 +17,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Adapt = "4"
Aqua = "0.8.9"
BlockArrays = "1"
BlockSparseArrays = "0.8, 0.9, 0.10"
BlockSparseArrays = "0.10"
Combinatorics = "1"
GradedArrays = "0.4"
NamedDimsArrays = "0.7"
SafeTestsets = "0.1"
Suppressor = "0.2"
TensorAlgebra = "0.3, 0.4"
TensorAlgebra = "0.4"
Test = "1.10"
19 changes: 18 additions & 1 deletion test/test_tensoralgebra.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LinearAlgebra: factorize, lq, norm, qr, svd
using NamedDimsArrays: dename, nameddimsindices, namedoneto
using NamedDimsArrays: NamedDimsArrays, dename, nameddimsindices, namedoneto
using StableRNGs: StableRNG
using TensorAlgebra:
TensorAlgebra,
contract,
Expand All @@ -14,6 +15,7 @@ using TensorAlgebra:
right_polar,
unmatricize
using Test: @test, @testset, @test_broken

elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "TensorAlgebra (eltype=$(elt))" for elt in elts
@testset "contract" begin
Expand Down Expand Up @@ -45,6 +47,21 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test dename(na_split, ("k", "i", "j", "l")) ≈
reshape(dename(na, ("a", "b")), (dename(k), dename(i), dename(j), dename(l)))
end
@testset "Matrix functions" begin
for f in NamedDimsArrays.MATRIX_FUNCTIONS
f == :cbrt && elt <: Complex && continue
f == :cbrt && VERSION < v"1.11-" && continue
@eval begin
i, j, k, l = namedoneto.((2, 2, 2, 2), ("i", "j", "k", "l"))
rng = StableRNG(123)
a = randn(rng, $elt, (i, j, k, l))
fa = $f(a, (j, l), (k, i))
m = dename(matricize(a, (j, l) => "a", (k, i) => "b"), ("a", "b"))
fm = dename(matricize(fa, (j, l) => "a", (k, i) => "b"), ("a", "b"))
@test fm ≈ $f(m)
end
end
end
@testset "qr/lq" begin
dims = (2, 2, 2, 2)
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))
Expand Down
Loading