Skip to content

Commit cf0042c

Browse files
bors[bot]darsnack
andauthored
Merge #1459
1459: Use fallback for reshape/cat OneHotArray r=DhairyaLGandhi a=darsnack This falls back to reshaping a `Bool` array whenever reshaping the first dimension of a `OneHotArray`. @DhairyaLGandhi @CarloLucibello @simeonschaub ### PR Checklist - [x] Tests are added - [ ] ~~Entry in NEWS.md~~ - [x] Documentation, if applicable Co-authored-by: Kyle Daruwalla <[email protected]> Co-authored-by: Kyle Daruwalla <[email protected]>
2 parents b5e5741 + d27139f commit cf0042c

File tree

2 files changed

+54
-28
lines changed

2 files changed

+54
-28
lines changed

src/onehot.jl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,24 @@ OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}
99
OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)
1010

1111
_indices(x::OneHotArray) = x.indices
12+
_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
13+
reshape(parent(x).indices, x.dims[2:end])
1214

1315
const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
1416
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}
1517

1618
OneHotVector(idx, L) = OneHotArray(idx, L)
1719
OneHotMatrix(indices, L) = OneHotArray(indices, L)
1820

21+
# use this type so reshaped arrays hit fast paths
22+
# e.g. argmax
23+
const OneHotLike{T, L, N, var"N+1", I} =
24+
Union{OneHotArray{T, L, N, var"N+1", I},
25+
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
26+
27+
_isonehot(x::OneHotArray) = true
28+
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
29+
1930
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
2031

2132
_onehotindex(x, i) = (x == i)
@@ -28,34 +39,30 @@ Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.i
2839
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
2940
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
3041

31-
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
32-
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
42+
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
43+
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
3344

34-
function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L
35-
if isone(dims)
36-
return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first."))
45+
function Base.cat(xs::OneHotLike{<:Any, L}...; dims::Int) where L
46+
if isone(dims) || any(x -> !_isonehot(x), xs)
47+
return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = dims)
3748
else
3849
return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L)
3950
end
4051
end
4152

42-
Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2)
43-
Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1)
44-
45-
Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L =
46-
(first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) :
47-
throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)"))
48-
Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims)
53+
Base.hcat(xs::OneHotLike...) = cat(xs...; dims = 2)
54+
Base.vcat(xs::OneHotLike...) = cat(xs...; dims = 1)
4955

5056
batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)
5157

52-
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L)
58+
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
5359

5460
Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()
5561

56-
Base.argmax(x::OneHotArray; dims = Colon()) =
57-
(dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) :
58-
argmax(convert(_onehot_bool_type(x), x); dims = dims)
62+
Base.argmax(x::OneHotLike; dims = Colon()) =
63+
(_isonehot(x) && dims == 1) ?
64+
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
65+
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)
5966

6067
"""
6168
onehot(l, labels[, unk])
@@ -135,11 +142,18 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
135142
end
136143

137144
_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
138-
_fast_argmax(x::OneHotArray) = x.indices
145+
function _fast_argmax(x::OneHotLike)
146+
if _isonehot(x)
147+
return _indices(x)
148+
else
149+
return _fast_argmax(convert(_onehot_bool_type(x), x))
150+
end
151+
end
139152

140153
@nograd OneHotArray, onecold, onehot, onehotbatch
141154

142-
function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L
155+
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
156+
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
143157
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
144158
return A[:, onecold(B)]
145159
end

test/onehot.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ end
3535
end
3636

3737
@testset "OneHotArray" begin
38-
using Flux: OneHotArray, OneHotVector, OneHotMatrix
38+
using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike
3939

4040
ov = OneHotVector(rand(1:10), 10)
4141
om = OneHotMatrix(rand(1:10, 5), 10)
@@ -74,27 +74,39 @@ end
7474
@testset "Concatenating" begin
7575
# vector cat
7676
@test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10)
77-
@test_throws ArgumentError vcat(ov, ov)
77+
@test vcat(ov, ov) == vcat(collect(ov), collect(ov))
7878
@test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10)
7979

8080
# matrix cat
8181
@test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10)
82-
@test_throws ArgumentError vcat(om, om)
82+
@test vcat(om, om) == vcat(collect(om), collect(om))
8383
@test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10)
8484

8585
# array cat
8686
@test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10)
87-
@test_throws ArgumentError cat(oa, oa; dims = 1)
87+
@test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1)
8888
end
8989

9090
@testset "Base.reshape" begin
9191
# reshape test
92-
@test reshape(oa, 10, 25) isa OneHotArray
93-
@test reshape(oa, 10, :) isa OneHotArray
94-
@test reshape(oa, :, 25) isa OneHotArray
95-
@test_throws ArgumentError reshape(oa, 50, :)
96-
@test_throws ArgumentError reshape(oa, 5, 10, 5)
97-
@test reshape(oa, (10, 25)) isa OneHotArray
92+
@test reshape(oa, 10, 25) isa OneHotLike
93+
@test reshape(oa, 10, :) isa OneHotLike
94+
@test reshape(oa, :, 25) isa OneHotLike
95+
@test reshape(oa, 50, :) isa OneHotLike
96+
@test reshape(oa, 5, 10, 5) isa OneHotLike
97+
@test reshape(oa, (10, 25)) isa OneHotLike
98+
99+
@testset "w/ cat" begin
100+
r = reshape(oa, 10, :)
101+
@test hcat(r, r) isa OneHotArray
102+
@test vcat(r, r) isa Array{Bool}
103+
end
104+
105+
@testset "w/ argmax" begin
106+
r = reshape(oa, 10, :)
107+
@test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10))
108+
@test Flux._fast_argmax(r) == collect(reshape(oa.indices, :))
109+
end
98110
end
99111

100112
@testset "Base.argmax" begin

0 commit comments

Comments
 (0)