Skip to content

Commit ba04d97

Browse files
Informative error message for softmax variants (#378)
* Informative error message for softmax variants * Updated message to be more informative Co-authored-by: Michael Abbott <[email protected]> Co-authored-by: Michael Abbott <[email protected]>
1 parent 9934079 commit ba04d97

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/softmax.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,9 @@ function logsumexp(x::AbstractArray; dims = :)
142142
max_ = maximum(x; dims = dims)
143143
max_ .+ log.(sum(exp.(x .- max_); dims = dims))
144144
end
145+
146+
# Informative error message if any of the softmax variants is called with a number
147+
for f in (:softmax, :logsoftmax, :softmax!, :logsoftmax!, :logsumexp)
148+
@eval $(f)(x::Number, args...) =
149+
error("`", $(string(f)), "(x)` called with a number, but it expects an array. Usually this is because a layer like `Dense(3,4,softmax)` is broadcasting it like an activation function; `softmax` needs to be outside the layer.")
150+
end

0 commit comments

Comments
 (0)