Skip to content

Commit 8031650

Browse files
authored
Fix incorrect perm, generalize some constructors (#24)
1 parent 920f4c3 commit 8031650

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.8"
4+
version = "0.3.9"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractnameddimsarray.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ const NamedDimsAxis = AbstractNamedUnitRange{
152152

153153
# Generic constructor.
154154
function nameddims(a::AbstractArray, nameddimsindices)
155+
if iszero(ndims(a))
156+
return constructorof_nameddims(typeof(a))(a, nameddimsindices)
157+
end
155158
# TODO: Check the shape of `nameddimsindices` matches the shape of `a`.
156159
arrtype = mapreduce(nameddimsarraytype, combine_nameddimsarraytype, nameddimsindices)
157160
return arrtype(a, to_nameddimsindices(a, nameddimsindices))
@@ -781,7 +784,7 @@ function set_promote_shape(
781784
ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}},
782785
ax2::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}},
783786
) where {N}
784-
perm = getperm(ax1, ax2)
787+
perm = getperm(ax2, ax1)
785788
ax2_aligned = map(i -> ax2[i], perm)
786789
ax_promoted = promote_shape(dename.(ax1), dename.(ax2_aligned))
787790
return named.(ax_promoted, name.(ax1))
@@ -813,7 +816,7 @@ function set_check_broadcast_shape(
813816
ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}},
814817
ax2::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}},
815818
) where {N}
816-
perm = getperm(ax1, ax2)
819+
perm = getperm(ax2, ax1)
817820
ax2_aligned = map(i -> ax2[i], perm)
818821
check_broadcast_shape(dename.(ax1), dename.(ax2_aligned))
819822
return nothing

src/tensoralgebra.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ function TensorAlgebra.contract(
2929
a_dest, nameddimsindices_dest = contract(
3030
dename(a1), nameddimsindices(a1), dename(a2), nameddimsindices(a2), α
3131
)
32-
return nameddims(a_dest, nameddimsindices_dest)
32+
nameddimstype = combine_nameddimsarraytype(
33+
constructorof(typeof(a1)), constructorof(typeof(a2))
34+
)
35+
return nameddimstype(a_dest, nameddimsindices_dest)
3336
end
3437

3538
function Base.:*(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray)

0 commit comments

Comments
 (0)