Skip to content

Commit d27139f

Browse files
Apply suggestions from code review
Co-authored-by: Simeon Schaub <[email protected]>
1 parent 9780b21 commit d27139f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/onehot.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) wh
6262
Base.argmax(x::OneHotLike; dims = Colon()) =
6363
(_isonehot(x) && dims == 1) ?
6464
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
65-
argmax(convert(_onehot_bool_type(x), x); dims = dims)
65+
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)
6666

6767
"""
6868
onehot(l, labels[, unk])
@@ -153,6 +153,7 @@ end
153153
@nograd OneHotArray, onecold, onehot, onehotbatch
154154

155155
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
156+
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
156157
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
157158
return A[:, onecold(B)]
158159
end

0 commit comments

Comments
 (0)