Skip to content

Commit d0932f0

Browse files
authored
Make size and axes output sets (#12)
1 parent 6c2e1b2 commit d0932f0

File tree

5 files changed

+68
-27
lines changed

5 files changed

+68
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ a2 = randn(j, k)
4747

4848
@test dimnames(a1) == ("i", "j")
4949
@test nameddimsindices(a1) == (i, j)
50-
@test axes(a1) == (named(1:2, i), named(1:2, j))
51-
@test size(a1) == (named(2, i), named(2, j))
50+
@test axes(a1, 1) == named(1:2, i)
51+
@test axes(a1, 2) == named(1:2, j)
52+
@test size(a1, 1) == named(2, i)
53+
@test size(a1, 2) == named(2, j)
5254

5355
# Indexing
5456
@test a1[j => 2, i => 1] == a1[1, 2]

examples/README.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ a2 = randn(j, k)
5252

5353
@test dimnames(a1) == ("i", "j")
5454
@test nameddimsindices(a1) == (i, j)
55-
@test axes(a1) == (named(1:2, i), named(1:2, j))
56-
@test size(a1) == (named(2, i), named(2, j))
55+
@test axes(a1, 1) == named(1:2, i)
56+
@test axes(a1, 2) == named(1:2, j)
57+
@test size(a1, 1) == named(2, i)
58+
@test size(a1, 2) == named(2, j)
5759

5860
## Indexing
5961
@test a1[j => 2, i => 1] == a1[1, 2]

src/abstractnameddimsarray.jl

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,41 @@ function combine_nameddimsarraytype(
169169
end
170170
combine_nameddimsarraytype(::Type{T}, ::Type{T}) where {T<:AbstractNamedDimsArray} = T
171171

172-
Base.axes(a::AbstractNamedDimsArray) = map(named, axes(dename(a)), nameddimsindices(a))
173-
Base.size(a::AbstractNamedDimsArray) = map(named, size(dename(a)), nameddimsindices(a))
172+
using Base.Broadcast: Broadcasted, Style
173+
174+
struct NaiveOrderedSet{Values}
175+
values::Values
176+
end
177+
Base.Tuple(s::NaiveOrderedSet) = s.values
178+
Base.length(s::NaiveOrderedSet) = length(Tuple(s))
179+
Base.axes(s::NaiveOrderedSet) = axes(Tuple(s))
180+
Base.:(==)(s1::NaiveOrderedSet, s2::NaiveOrderedSet) = issetequal(Tuple(s1), Tuple(s2))
181+
Base.iterate(s::NaiveOrderedSet, args...) = iterate(Tuple(s), args...)
182+
Base.getindex(s::NaiveOrderedSet, I::Int) = Tuple(s)[I]
183+
Base.invperm(s::NaiveOrderedSet) = NaiveOrderedSet(invperm(Tuple(s)))
184+
Base.Broadcast._axes(::Broadcasted, axes::NaiveOrderedSet) = axes
185+
Base.Broadcast.BroadcastStyle(::Type{<:NaiveOrderedSet}) = Style{NaiveOrderedSet}()
186+
Base.Broadcast.broadcastable(s::NaiveOrderedSet) = s
187+
188+
function Base.copy(
189+
bc::Broadcasted{Style{NaiveOrderedSet},<:Any,<:Any,<:Tuple{<:NaiveOrderedSet}}
190+
)
191+
return NaiveOrderedSet(bc.f.(Tuple(only(bc.args))))
192+
end
193+
# Multiple arguments not supported.
194+
function Base.copy(bc::Broadcasted{Style{NaiveOrderedSet}})
195+
return error("This broadcasting expression of `NaiveOrderedSet` is not supported.")
196+
end
197+
function Base.map(f::Function, s::NaiveOrderedSet)
198+
return NaiveOrderedSet(map(f, Tuple(s)))
199+
end
200+
201+
function Base.axes(a::AbstractNamedDimsArray)
202+
return NaiveOrderedSet(map(named, axes(dename(a)), nameddimsindices(a)))
203+
end
204+
function Base.size(a::AbstractNamedDimsArray)
205+
return NaiveOrderedSet(map(named, size(dename(a)), nameddimsindices(a)))
206+
end
174207

175208
# Circumvent issue when ndims isn't known at compile time.
176209
function Base.axes(a::AbstractNamedDimsArray, d)
@@ -267,6 +300,9 @@ struct NamedDimsCartesianIndices{
267300
return new{length(indices),typeof(indices),Tuple{eltype.(indices)...}}(indices)
268301
end
269302
end
303+
function NamedDimsCartesianIndices(indices::NaiveOrderedSet)
304+
return NamedDimsCartesianIndices(Tuple(indices))
305+
end
270306

271307
Base.eltype(I::NamedDimsCartesianIndices) = eltype(typeof(I))
272308
Base.axes(I::NamedDimsCartesianIndices) = map(only axes, I.indices)
@@ -672,20 +708,25 @@ end
672708
Broadcast.combine_axes(a::AbstractNamedDimsArray) = axes(a)
673709

674710
function Broadcast.broadcast_shape(
675-
ax1::Tuple{Vararg{AbstractNamedUnitRange}},
676-
ax2::Tuple{Vararg{AbstractNamedUnitRange}},
677-
ax_rest::Tuple{Vararg{AbstractNamedUnitRange}}...,
711+
ax1::NaiveOrderedSet, ax2::NaiveOrderedSet, ax_rest::NaiveOrderedSet...
678712
)
679713
return broadcast_shape(broadcast_shape(ax1, ax2), ax_rest...)
680714
end
681715

682-
function Broadcast.broadcast_shape(
683-
ax1::Tuple{Vararg{AbstractNamedUnitRange}}, ax2::Tuple{Vararg{AbstractNamedUnitRange}}
684-
)
716+
function Broadcast.broadcast_shape(ax1::NaiveOrderedSet, ax2::NaiveOrderedSet)
685717
return promote_shape(ax1, ax2)
686718
end
687719

688-
function Base.promote_shape(
720+
# Handle scalar values.
721+
function Broadcast.broadcast_shape(ax1::Tuple{}, ax2::NaiveOrderedSet)
722+
return ax2
723+
end
724+
725+
function Base.promote_shape(ax1::NaiveOrderedSet, ax2::NaiveOrderedSet)
726+
return NaiveOrderedSet(set_promote_shape(Tuple(ax1), Tuple(ax2)))
727+
end
728+
729+
function set_promote_shape(
689730
ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}},
690731
ax2::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}},
691732
) where {N}
@@ -695,8 +736,11 @@ function Base.promote_shape(
695736
return named.(ax_promoted, name.(ax1))
696737
end
697738

698-
# Avoid comparison of `NamedInteger` against `1`.
699-
function Broadcast.check_broadcast_shape(
739+
function Broadcast.check_broadcast_shape(ax1::NaiveOrderedSet, ax2::NaiveOrderedSet)
740+
return set_check_broadcast_shape(Tuple(ax1), Tuple(ax2))
741+
end
742+
743+
function set_check_broadcast_shape(
700744
ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}},
701745
ax2::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}},
702746
) where {N}
@@ -706,24 +750,17 @@ function Broadcast.check_broadcast_shape(
706750
return nothing
707751
end
708752

709-
# Handle scalars.
710-
function Base.promote_shape(
711-
ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}}, ax2::Tuple{}
712-
)
713-
return ax1
714-
end
715-
716753
# Dename and lazily permute the arguments using the reference
717754
# dimension names.
718755
# TODO: Make a version that gets the nameddimsindices from `m`.
719756
function denamed(m::Mapped, nameddimsindices)
720757
return mapped(m.f, map(arg -> denamed(arg, nameddimsindices), m.args)...)
721758
end
722759

723-
function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, ax::Tuple)
760+
function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, ax)
724761
nameddimsindices = name.(ax)
725762
m′ = denamed(Mapped(bc), nameddimsindices)
726-
return nameddims(similar(m′, elt, dename.(ax)), nameddimsindices)
763+
return nameddims(similar(m′, elt, dename.(Tuple(ax))), nameddimsindices)
727764
end
728765

729766
function Base.copyto!(

test/basics/test_basics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ using Test: @test, @test_throws, @testset
141141
nb = setnameddimsindices(na, named(3, "i") => named(3, "k"))
142142
na[1, 1] = 11
143143
@test na[1, 1] == 11
144-
@test size(na) == (named(3, named(1:3, "i")), named(4, named(1:4, "j")))
144+
@test Tuple(size(na)) == (named(3, named(1:3, "i")), named(4, named(1:4, "j")))
145145
@test length(na) == named(12, fusednames(named(1:3, "i"), named(1:4, "j")))
146-
@test axes(na) == (named(1:3, named(1:3, "i")), named(1:4, named(1:4, "j")))
146+
@test Tuple(axes(na)) == (named(1:3, named(1:3, "i")), named(1:4, named(1:4, "j")))
147147
@test randn(named.((3, 4), ("i", "j"))) isa NamedDimsArray
148148
@test na["i" => 1, "j" => 2] == a[1, 2]
149149
@test na["j" => 2, "i" => 1] == a[1, 2]

0 commit comments

Comments
 (0)