@@ -9,13 +9,24 @@ OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}
9
9
OneHotArray (indices:: AbstractArray{T, N} , L:: Integer ) where {T, N} = OneHotArray {T, L, N, typeof(indices)} (indices)
10
10
11
11
_indices (x:: OneHotArray ) = x. indices
12
+ _indices (x:: Base.ReshapedArray{<:Any, <:Any, <:OneHotArray} ) =
13
+ reshape (parent (x). indices, x. dims[2 : end ])
12
14
13
15
const OneHotVector{T, L} = OneHotArray{T, L, 0 , 1 , T}
14
16
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1 , 2 , I}
15
17
16
18
OneHotVector (idx, L) = OneHotArray (idx, L)
17
19
OneHotMatrix (indices, L) = OneHotArray (indices, L)
18
20
21
+ # use this type so reshaped arrays hit fast paths
22
+ # e.g. argmax
23
+ const OneHotLike{T, L, N, var"N+1" , I} =
24
+ Union{OneHotArray{T, L, N, var"N+1" , I},
25
+ Base. ReshapedArray{Bool, var"N+1" , <: OneHotArray{T, L, <:Any, <:Any, I} }}
26
+
27
+ _isonehot (x:: OneHotArray ) = true
28
+ _isonehot (x:: Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}} ) where L = (size (x, 1 ) == L)
29
+
19
30
Base. size (x:: OneHotArray{<:Any, L} ) where L = (Int (L), size (x. indices)... )
20
31
21
32
_onehotindex (x, i) = (x == i)
@@ -28,34 +39,30 @@ Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.i
28
39
Base. getindex (x:: OneHotArray{<:Any, <:Any, <:Any, N} , :: Vararg{Colon, N} ) where N = x
29
40
Base. getindex (x:: OneHotArray , I:: CartesianIndex{N} ) where N = x[I[1 ], Tuple (I)[2 : N]. .. ]
30
41
31
- _onehot_bool_type (x:: OneHotArray {<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}} ) where N = Array{Bool, N}
32
- _onehot_bool_type (x:: OneHotArray {<:Any, <:Any, <:Any, N, <:CuArray} ) where N = CuArray{Bool, N}
42
+ _onehot_bool_type (x:: OneHotLike {<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}} ) where N = Array{Bool, N}
43
+ _onehot_bool_type (x:: OneHotLike {<:Any, <:Any, <:Any, N, <:CuArray} ) where N = CuArray{Bool, N}
33
44
34
- function Base. cat (xs:: OneHotArray {<:Any, L} ...; dims:: Int ) where L
35
- if isone (dims)
36
- return throw ( ArgumentError ( " Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first. " ) )
45
+ function Base. cat (xs:: OneHotLike {<:Any, L} ...; dims:: Int ) where L
46
+ if isone (dims) || any (x -> ! _isonehot (x), xs)
47
+ return cat ( map (x -> convert ( _onehot_bool_type (x), x), xs) ... ; dims = dims )
37
48
else
38
49
return OneHotArray (cat (_indices .(xs)... ; dims = dims - 1 ), L)
39
50
end
40
51
end
41
52
42
- Base. hcat (xs:: OneHotArray... ) = cat (xs... ; dims = 2 )
43
- Base. vcat (xs:: OneHotArray... ) = cat (xs... ; dims = 1 )
44
-
45
- Base. reshape (x:: OneHotArray{<:Any, L} , dims:: Dims ) where L =
46
- (first (dims) == L) ? OneHotArray (reshape (x. indices, dims[2 : end ]. .. ), L) :
47
- throw (ArgumentError (" Cannot reshape OneHotArray if first(dims) != size(x, 1)" ))
48
- Base. _reshape (x:: OneHotArray , dims:: Tuple{Vararg{Int}} ) = reshape (x, dims)
53
+ Base. hcat (xs:: OneHotLike... ) = cat (xs... ; dims = 2 )
54
+ Base. vcat (xs:: OneHotLike... ) = cat (xs... ; dims = 1 )
49
55
50
56
batch (xs:: AbstractArray{<:OneHotVector{<:Any, L}} ) where L = OneHotArray (_indices .(xs), L)
51
57
52
- Adapt. adapt_structure (T, x:: OneHotArray{<:Any, L} ) where L = OneHotArray (adapt (T, x . indices ), L)
58
+ Adapt. adapt_structure (T, x:: OneHotArray{<:Any, L} ) where L = OneHotArray (adapt (T, _indices (x) ), L)
53
59
54
60
Base. BroadcastStyle (:: Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}} ) where N = CUDA. CuArrayStyle {N} ()
55
61
56
- Base. argmax (x:: OneHotArray ; dims = Colon ()) =
57
- (dims == 1 ) ? reshape (CartesianIndex .(x. indices, CartesianIndices (x. indices)), 1 , size (x. indices)... ) :
58
- argmax (convert (_onehot_bool_type (x), x); dims = dims)
62
+ Base. argmax (x:: OneHotLike ; dims = Colon ()) =
63
+ (_isonehot (x) && dims == 1 ) ?
64
+ reshape (CartesianIndex .(_indices (x), CartesianIndices (_indices (x))), 1 , size (_indices (x))... ) :
65
+ invoke (argmax, Tuple{AbstractArray}, x; dims = dims)
59
66
60
67
"""
61
68
onehot(l, labels[, unk])
@@ -135,11 +142,18 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
135
142
end
136
143
137
144
_fast_argmax (x:: AbstractArray ) = dropdims (argmax (x; dims = 1 ); dims = 1 )
138
- _fast_argmax (x:: OneHotArray ) = x. indices
145
+ function _fast_argmax (x:: OneHotLike )
146
+ if _isonehot (x)
147
+ return _indices (x)
148
+ else
149
+ return _fast_argmax (convert (_onehot_bool_type (x), x))
150
+ end
151
+ end
139
152
140
153
@nograd OneHotArray, onecold, onehot, onehotbatch
141
154
142
- function Base.:(* )(A:: AbstractMatrix , B:: OneHotArray{<:Any, L} ) where L
155
+ function Base.:(* )(A:: AbstractMatrix , B:: OneHotLike{<:Any, L} ) where L
156
+ _isonehot (B) || return invoke (* , Tuple{AbstractMatrix, AbstractMatrix}, A, B)
143
157
size (A, 2 ) == L || throw (DimensionMismatch (" Matrix column must correspond with OneHot size: $(size (A, 2 )) != $L " ))
144
158
return A[:, onecold (B)]
145
159
end
0 commit comments