Skip to content
12 changes: 8 additions & 4 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
OneHotArray{T, N, M, I} <: AbstractArray{Bool, M}
OneHotArray(indices, L)
OneHotArray(indices, L, [axis=1])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do add this, it should probably be a dims::Integer keyword on onehotbatch. IMO it's weird if a type constructor does not return the stated type.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the sentiment, but it seems awkward to have to maintain a whole set of functions, and to increase their complexity, just to provide the same functionality.
If this is a complete no-go, then the options are either the alternative implementation (which might have even more problems) or adding this functionality as a separate utility function.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it more complex to alter the lower-case function than the upper-case type constructor?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought some more about this point and I agree changing onehot/onehotbatch/etc is the better approach.
I'll take a closer look at the functions (still not very familiar with them all).
Which would you say would be appropriate?


A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, 1) == L` and `sum(A, dims=1) == 1`)
A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, axis) == L` and `sum(A, dims=axis) == 1`)
stored as a compact `N == M-1`-dimensional array of indices.

Typically constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
Expand All @@ -15,6 +15,10 @@ end
# Convenience constructors: each forwards to the fully parameterised
# `OneHotArray{T, N, M, I}` with `M = N + 1`.

# Explicit `{T, N, I}` parameters; only the total dimensionality is filled in.
function OneHotArray{T, N, I}(indices, L::Int) where {T, N, I}
    return OneHotArray{T, N, N+1, I}(indices, L)
end

# A single integer index: zero index dimensions, giving a 1-dimensional array.
function OneHotArray(indices::T, L::Int) where {T<:Integer}
    return OneHotArray{T, 0, 1, T}(indices, L)
end

# An `N`-dimensional array of indices: element type and rank are read off its type.
function OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}}
    return OneHotArray{T, N, N+1, I}(indices, L)
end
# OneHotArray(indices, L, axis): one-hot array whose label dimension (length `L`) is
# placed at position `axis` instead of dimension 1.
# The result is a lazy `PermutedDimsArray` wrapping `OneHotArray(indices, L)`; it is
# not itself an `OneHotArray`.
# Throws `ArgumentError` when `axis` is outside `1:ndims(indices)+1`.
function OneHotArray(indices, L, axis::Int)
    M = ndims(indices) + 1
    1 <= axis <= M || throw(ArgumentError("axis must be in 1:$M, got $axis"))
    # `PermutedDimsArray(A, perm)` satisfies `size(B, d) == size(A, perm[d])`, so the
    # parent's label dimension (1) has to sit at `perm[axis]`; the remaining
    # dimensions keep their relative order around it.
    perm = ntuple(d -> d == axis ? 1 : (d < axis ? d + 1 : d), M)
    return PermutedDimsArray(OneHotArray(indices, L), perm)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The permutation can be computed without mutating an array, something like this:

julia> let dims=2
       ntuple(d -> d==dims ? 1 : d<dims ? d+1 : d, 4)
       end
(2, 1, 3, 4)

julia> let dims=3
       ntuple(d -> d==dims ? 1 : d<dims ? d+1 : d, 4)
       end
(2, 3, 1, 4)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that your second example gives `(2, 3, 1, 4)`, which is not what I meant.
But the general suggestion is on point.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

julia> dims = 3
julia> ntuple(d -> (d==dims ? 1 : (d==1 ? dims : d)), 4)
(3, 2, 1, 4)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I might be off by an `invperm`, but I did not check carefully.

end

_indices(x::OneHotArray) = x.indices
_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
Expand Down Expand Up @@ -69,7 +73,7 @@ end
# the method above is faster on the CPU but will scalar index on the GPU
# so we define the method below to pass the extra indices directly to GPU array
function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray},
i::Int,
i::Int,
I::Vararg{Any, N}) where N
@boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...)))
return x.indices[I...] .== i
Expand Down Expand Up @@ -154,5 +158,5 @@ Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

# `argmax` for one-hot-like arrays.
# When `_isonehot(x)` holds and `dims == 1`, the answer is built from `_indices(x)`:
# each stored index is paired with its position via `CartesianIndex`, and the result
# is reshaped to carry a leading singleton dimension.
# In every other case the call is forwarded to the generic `argmax(::AbstractArray)`
# method through `invoke`.
# NOTE(review): the `reshape(...) :` line appears twice below; this looks like the old
# and new sides of a whitespace-only diff rather than intended code. Confirm against
# the real file.
Base.argmax(x::OneHotLike; dims = Colon()) =
(_isonehot(x) && dims == 1) ?
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)
33 changes: 33 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ ov2 = OneHotVector(rand(1:11), 11)
# Shared fixtures used by the testsets below (indices are drawn at random).
# `om`: 5 indices over 10 labels, size (10, 5).
om = OneHotMatrix(rand(1:10, 5), 10)
# `om2`: same batch size as `om` but with 11 labels.
om2 = OneHotMatrix(rand(1:11, 5), 11)
# `oa`: 5x5 indices over 10 labels, size (10, 5, 5).
oa = OneHotArray(rand(1:10, 5, 5), 10)
# `oa2`: like `oa`, but built with the three-argument form (axis = 2), size (5, 10, 5).
oa2 = OneHotArray(rand(1:10, 5, 5), 10, 2)

# sizes
@testset "Base.size" begin
    # Each fixture paired with the shape it is expected to report; for `oa2` the
    # length-10 label dimension is the second one.
    expected = ((ov, (10,)), (om, (10, 5)), (oa, (10, 5, 5)), (oa2, (5, 10, 5)))
    for (arr, dims) in expected
        @test size(arr) == dims
    end
end

@testset "Indexing" begin
Expand All @@ -32,18 +34,30 @@ end
@test oa[:, :, :] == oa
@test oa[:] == reshape(oa, :)

@test oa2[3, 3, 3] == (oa2.parent.indices[3, 3] == 3)
@test oa2[3, :, 3] == OneHotVector(oa2.parent.indices[3, 3], 10)
@test oa2[:, 3, 3] == (oa2.parent.indices[:, 3] .== 3)
@test oa2[:, 3, :] == (oa2.parent.indices .== 3)
@test oa2[3, :, :] == OneHotMatrix(oa2.parent.indices[3, :], 10)
@test oa2[:, :, :] == oa2
@test oa2[:] == reshape(oa2, :)

# cartesian indexing
@test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3]
@test oa2[CartesianIndex(3, 3, 3)] == oa2[3, 3, 3]

# linear indexing
@test om[11] == om[1, 2]
@test oa[52] == oa[2, 1, 2]
@test oa2[55] == oa2[1, 2, 2]

# bounds checks
@test_throws BoundsError ov[0]
@test_throws BoundsError om[2, -1]
@test_throws BoundsError oa[11, 5, 5]
@test_throws BoundsError oa[:, :]
@test_throws BoundsError oa2[5, 11, 5]
@test_throws BoundsError oa2[:, :]
end

@testset "Concatenating" begin
Expand All @@ -64,6 +78,9 @@ end
@test cat(oa, oa; dims = 3) isa OneHotArray
@test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1)

@test cat(oa2, oa2; dims = 3) == OneHotArray(cat(oa2.parent.indices, oa2.parent.indices; dims = 2), 10, 2)
@test cat(oa2, oa2; dims = 2) == cat(collect(oa2), collect(oa2); dims = 2)

# stack
@test stack([ov, ov]) == hcat(ov, ov)
@test stack([ov, ov, ov]) isa OneHotMatrix
Expand Down Expand Up @@ -96,6 +113,18 @@ end
@test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10))
@test OneHotArrays._fast_argmax(r) == collect(reshape(oa.indices, :))
end

@testset "w/ cat" begin
    # Reshape the axis-2 array into 10 rows, then check that stacking two copies
    # vertically yields a plain `Array{Bool}`.
    flat = reshape(oa2, 10, :)
    @test vcat(flat, flat) isa Array{Bool}
end

@testset "w/ argmax" begin
    # Swap the first two dimensions of `oa2`, so the length-10 dimension leads again.
    oa2p = PermutedDimsArray(oa2, [2,1,3])
    # Two `parent` hops: `oa2p` wraps `oa2`, which wraps the underlying one-hot array.
    labels = reshape(oa2p.parent.parent.indices, :)
    flat = reshape(oa2p, 10, :)
    @test argmax(flat) == argmax(OneHotMatrix(labels, 10))
    # The first coordinate of every `_fast_argmax` result should equal the stored label.
    @test stack(collect(Tuple.(OneHotArrays._fast_argmax(flat))))[1,:] == collect(labels)
end
end

@testset "Base.argmax" begin
Expand All @@ -106,9 +135,13 @@ end
@test argmax(om; dims = 2) == argmax(convert(Array{Bool}, om); dims = 2)
@test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1)
@test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3)
@test argmax(oa2; dims = 2) == argmax(convert(Array{Bool}, oa2); dims = 2)
@test argmax(oa2; dims = 3) == argmax(convert(Array{Bool}, oa2); dims = 3)
end

@testset "Forward map to broadcast" begin
    # `map` over either fixture should agree with the equivalent broadcast result.
    for arr in (oa, oa2)
        @test map(identity, arr) == arr
        @test map(x -> 2 * x, arr) == 2 .* arr
    end
end