Skip to content

Commit 9780b21

Browse files
committed
Check for one-hotness on fast paths instead of ReshapedArray constructor
1 parent 2835923 commit 9780b21

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

src/onehot.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,10 @@ OneHotMatrix(indices, L) = OneHotArray(indices, L)
2222
# e.g. argmax
2323
const OneHotLike{T, L, N, var"N+1", I} =
2424
Union{OneHotArray{T, L, N, var"N+1", I},
25-
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L}}}
26-
27-
# when reshaping a OneHotArray and first(dims) != L
28-
# convert the parent array to Array{Bool}
29-
# so that the ReshapedArray does not hit fast paths
30-
function Base.ReshapedArray(parent::OneHotArray{<:Any, L}, dims::NTuple{N,Int}, mi) where {L, N}
31-
parent = (first(dims) != L) ? convert(_onehot_bool_type(parent), parent) : parent
32-
33-
Base.ReshapedArray{Bool,N,typeof(parent),typeof(mi)}(parent, dims, mi)
34-
end
25+
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
26+
27+
_isonehot(x::OneHotArray) = true
28+
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
3529

3630
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
3731

@@ -41,15 +35,15 @@ Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
4135
Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x
4236

4337
Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
44-
Base.getindex(x::OneHotLike{<:Any, L}, ::Colon, I...) where L = OneHotArray(_indices(x)[I...], L)
38+
Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L)
4539
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
4640
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
4741

4842
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
4943
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
5044

5145
function Base.cat(xs::OneHotLike{<:Any, L}...; dims::Int) where L
52-
if isone(dims)
46+
if isone(dims) || any(x -> !_isonehot(x), xs)
5347
return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = dims)
5448
else
5549
return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L)
@@ -61,13 +55,14 @@ Base.vcat(xs::OneHotLike...) = cat(xs...; dims = 1)
6155

6256
batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)
6357

64-
Adapt.adapt_structure(T, x::OneHotLike{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
58+
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
6559

66-
Base.BroadcastStyle(::Type{<:OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()
60+
Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()
6761

6862
Base.argmax(x::OneHotLike; dims = Colon()) =
69-
(dims == 1) ? reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
70-
argmax(convert(_onehot_bool_type(x), x); dims = dims)
63+
(_isonehot(x) && dims == 1) ?
64+
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
65+
argmax(convert(_onehot_bool_type(x), x); dims = dims)
7166

7267
"""
7368
onehot(l, labels[, unk])
@@ -147,7 +142,13 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
147142
end
148143

149144
_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
150-
_fast_argmax(x::OneHotLike) = _indices(x)
145+
function _fast_argmax(x::OneHotLike)
146+
if _isonehot(x)
147+
return _indices(x)
148+
else
149+
return _fast_argmax(convert(_onehot_bool_type(x), x))
150+
end
151+
end
151152

152153
@nograd OneHotArray, onecold, onehot, onehotbatch
153154

test/onehot.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,21 @@ end
9292
@test reshape(oa, 10, 25) isa OneHotLike
9393
@test reshape(oa, 10, :) isa OneHotLike
9494
@test reshape(oa, :, 25) isa OneHotLike
95-
@test reshape(oa, 50, :) isa Base.ReshapedArray{<:Any, <:Any, <:Array}
96-
@test reshape(oa, 5, 10, 5) isa Base.ReshapedArray{<:Any, <:Any, <:Array}
95+
@test reshape(oa, 50, :) isa OneHotLike
96+
@test reshape(oa, 5, 10, 5) isa OneHotLike
9797
@test reshape(oa, (10, 25)) isa OneHotLike
98+
99+
@testset "w/ cat" begin
100+
r = reshape(oa, 10, :)
101+
@test hcat(r, r) isa OneHotArray
102+
@test vcat(r, r) isa Array{Bool}
103+
end
104+
105+
@testset "w/ argmax" begin
106+
r = reshape(oa, 10, :)
107+
@test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10))
108+
@test Flux._fast_argmax(r) == collect(reshape(oa.indices, :))
109+
end
98110
end
99111

100112
@testset "Base.argmax" begin

0 commit comments

Comments
 (0)