@@ -22,16 +22,10 @@ OneHotMatrix(indices, L) = OneHotArray(indices, L)
22
22
# e.g. argmax
23
23
const OneHotLike{T, L, N, var"N+1" , I} =
24
24
Union{OneHotArray{T, L, N, var"N+1" , I},
25
- Base. ReshapedArray{Bool, var"N+1" , <: OneHotArray{T, L} }}
26
-
27
- # when reshaping a OneHotArray and first(dims) != L
28
- # convert the parent array to Array{Bool}
29
- # so that the ReshapedArray does not hit fast paths
30
- function Base. ReshapedArray (parent:: OneHotArray{<:Any, L} , dims:: NTuple{N,Int} , mi) where {L, N}
31
- parent = (first (dims) != L) ? convert (_onehot_bool_type (parent), parent) : parent
32
-
33
- Base. ReshapedArray {Bool,N,typeof(parent),typeof(mi)} (parent, dims, mi)
34
- end
25
+ Base. ReshapedArray{Bool, var"N+1" , <: OneHotArray{T, L, <:Any, <:Any, I} }}
26
+
27
+ _isonehot (x:: OneHotArray ) = true
28
+ _isonehot (x:: Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}} ) where L = (size (x, 1 ) == L)
35
29
36
30
Base. size (x:: OneHotArray{<:Any, L} ) where L = (Int (L), size (x. indices)... )
37
31
@@ -41,15 +35,15 @@ Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
41
35
Base. getindex (x:: OneHotVector{T, L} , :: Colon ) where {T, L} = x
42
36
43
37
Base. getindex (x:: OneHotArray , i:: Integer , I... ) = _onehotindex .(x. indices[I... ], i)
44
- Base. getindex (x:: OneHotLike {<:Any, L} , :: Colon , I... ) where L = OneHotArray (_indices (x) [I... ], L)
38
+ Base. getindex (x:: OneHotArray {<:Any, L} , :: Colon , I... ) where L = OneHotArray (x . indices [I... ], L)
45
39
Base. getindex (x:: OneHotArray{<:Any, <:Any, <:Any, N} , :: Vararg{Colon, N} ) where N = x
46
40
Base. getindex (x:: OneHotArray , I:: CartesianIndex{N} ) where N = x[I[1 ], Tuple (I)[2 : N]. .. ]
47
41
48
42
_onehot_bool_type (x:: OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}} ) where N = Array{Bool, N}
49
43
_onehot_bool_type (x:: OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray} ) where N = CuArray{Bool, N}
50
44
51
45
function Base. cat (xs:: OneHotLike{<:Any, L} ...; dims:: Int ) where L
52
- if isone (dims)
46
+ if isone (dims) || any (x -> ! _isonehot (x), xs)
53
47
return cat (map (x -> convert (_onehot_bool_type (x), x), xs)... ; dims = dims)
54
48
else
55
49
return OneHotArray (cat (_indices .(xs)... ; dims = dims - 1 ), L)
@@ -61,13 +55,14 @@ Base.vcat(xs::OneHotLike...) = cat(xs...; dims = 1)
61
55
62
56
batch (xs:: AbstractArray{<:OneHotVector{<:Any, L}} ) where L = OneHotArray (_indices .(xs), L)
63
57
64
- Adapt. adapt_structure (T, x:: OneHotLike {<:Any, L} ) where L = OneHotArray (adapt (T, _indices (x)), L)
58
+ Adapt. adapt_structure (T, x:: OneHotArray {<:Any, L} ) where L = OneHotArray (adapt (T, _indices (x)), L)
65
59
66
- Base. BroadcastStyle (:: Type{<:OneHotLike {<:Any, <:Any, <:Any, N, <:CuArray}} ) where N = CUDA. CuArrayStyle {N} ()
60
+ Base. BroadcastStyle (:: Type{<:OneHotArray {<:Any, <:Any, <:Any, N, <:CuArray}} ) where N = CUDA. CuArrayStyle {N} ()
67
61
68
62
Base. argmax (x:: OneHotLike ; dims = Colon ()) =
69
- (dims == 1 ) ? reshape (CartesianIndex .(_indices (x), CartesianIndices (_indices (x))), 1 , size (_indices (x))... ) :
70
- argmax (convert (_onehot_bool_type (x), x); dims = dims)
63
+ (_isonehot (x) && dims == 1 ) ?
64
+ reshape (CartesianIndex .(_indices (x), CartesianIndices (_indices (x))), 1 , size (_indices (x))... ) :
65
+ argmax (convert (_onehot_bool_type (x), x); dims = dims)
71
66
72
67
"""
73
68
onehot(l, labels[, unk])
@@ -147,7 +142,13 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
147
142
end
148
143
149
144
_fast_argmax (x:: AbstractArray ) = dropdims (argmax (x; dims = 1 ); dims = 1 )
150
- _fast_argmax (x:: OneHotLike ) = _indices (x)
145
+ function _fast_argmax (x:: OneHotLike )
146
+ if _isonehot (x)
147
+ return _indices (x)
148
+ else
149
+ return _fast_argmax (convert (_onehot_bool_type (x), x))
150
+ end
151
+ end
151
152
152
153
@nograd OneHotArray, onecold, onehot, onehotbatch
153
154
0 commit comments