Skip to content

Commit 5b074f1

Browse files
committed
Try fixing tests
1 parent 5bcd80f commit 5b074f1

File tree

6 files changed

+23
-3
lines changed

6 files changed

+23
-3
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
99
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
1010
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1111
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
12+
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
1213
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
1314
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
1415
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -34,6 +35,7 @@ Adapt = "4.3"
3435
BackendSelection = "0.1.6"
3536
Combinatorics = "1"
3637
DataGraphs = "0.2.7"
38+
DerivableInterfaces = "0.5.5"
3739
DiagonalArrays = "0.3.23"
3840
Dictionaries = "0.4.5"
3941
Graphs = "1.13.1"

src/LazyNamedDimsArrays/lazyinterface.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,16 @@ function equals_lazy(a1, a2)
109109
return false
110110
end
111111
end
112+
function isequal_lazy(a1, a2)
113+
u1, u2 = unwrap.((a1, a2))
114+
if !iscall(u1) && !iscall(u2)
115+
return isequal(u1, u2)
116+
elseif ismul(u1) && ismul(u2)
117+
return isequal(arguments(u1), arguments(u2))
118+
else
119+
return false
120+
end
121+
end
112122
function hash_lazy(a, h::UInt64)
113123
h = hash(Symbol(unspecify_type_parameters(typeof(a))), h)
114124
# Use `_hash`, which defines a custom hash for NamedDimsArray.

src/LazyNamedDimsArrays/lazynameddimsarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ AbstractTrees.nodevalue(a::LazyNamedDimsArray) = nodevalue_lazy(a)
5050
Base.Broadcast.materialize(a::LazyNamedDimsArray) = materialize_lazy(a)
5151
Base.copy(a::LazyNamedDimsArray) = copy_lazy(a)
5252
Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = equals_lazy(a1, a2)
53+
Base.isequal(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = isequal_lazy(a1, a2)
5354
Base.hash(a::LazyNamedDimsArray, h::UInt64) = hash_lazy(a, h)
5455
map_arguments(f, a::LazyNamedDimsArray) = map_arguments_lazy(f, a)
5556
substitute(a::LazyNamedDimsArray, substitutions) = substitute_lazy(a, substitutions)

src/LazyNamedDimsArrays/symbolicarray.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# TODO: Allow dynamic/unknown number of dimensions by supporting vector axes.
12
struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N}
23
name::Name
34
axes::Axes
@@ -26,6 +27,12 @@ end
2627
function Base.setindex!(a::SymbolicArray{<:Any, N}, value, I::Vararg{Int, N}) where {N}
2728
return error("Indexing into SymbolicArray not supported.")
2829
end
30+
using DerivableInterfaces: DerivableInterfaces
31+
DerivableInterfaces.permuteddims(a::SymbolicArray, p) = permutedims(a, p)
32+
function Base.permutedims(a::SymbolicArray, p)
33+
@assert ndims(a) == length(p) && isperm(p)
34+
return SymbolicArray(symname(a), ntuple(i -> axes(a)[p[i]], ndims(a)))
35+
end
2936
function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray)
3037
Base.summary(io, a)
3138
println(io, ":")

src/LazyNamedDimsArrays/symbolicnameddimsarray.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
using NamedDimsArrays: NamedDimsArray, dename, inds
1+
using NamedDimsArrays: NamedDimsArray, dename, inds, nameddims
22

33
const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
44
NamedDimsArray{T, N, Parent, DimNames}
55
function symnameddims(name, dims)
6-
return lazy(NamedDimsArray(SymbolicArray(name, dename.(dims)), dims))
6+
return lazy(nameddims(SymbolicArray(name, dename.(dims)), dims))
77
end
88
symnameddims(name) = symnameddims(name, ())
99
using AbstractTrees: AbstractTrees

src/contract_network.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg,
4242
return contraction_order(to_algorithm(alg; kwargs...), tn)
4343
end
4444
function contraction_order(alg::Algorithm"left_associative", tn)
45-
return prod(i -> symnameddims(i, Tuple(inds(tn[i]))), eachindex(tn))
45+
return prod(i -> symnameddims(i, Tuple(inds(tn[i]))), keys(tn))
4646
end
4747
function contraction_order(alg::Algorithm, tn)
4848
s = contraction_order(tn; alg = Algorithm"left_associative"())

0 commit comments

Comments
 (0)