Skip to content

Commit aad08b5

Browse files
authored
Merge pull request #423 from theabhirath/broadcast-act
Define activation functions taking arrays as input
2 parents 4ebb419 + d651024 commit aad08b5

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/activations.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -731,10 +731,9 @@ function softshrink(x, λ = 0.5)
731731
ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)
732732
end
733733

734-
# Provide an informative error message if activation functions are called with an array
734+
# Define broadcasts for activation functions on arrays
735735
for f in ACTIVATIONS
736-
@eval $(f)(x::AbstractArray, args...) =
737-
error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
736+
@eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)
738737
end
739738

740739
## Faster, less accurate, versions of some.

test/activations.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,14 @@ end
9393
@testset "Array input -> error" begin
9494
x = rand(5)
9595
for a in ACTIVATION_FUNCTIONS
96-
@test_throws ErrorException a(x)
96+
@test size(a(x)) == size(x)
97+
grad = Zygote.gradient(p -> sum(a(p)), x)
98+
@test size(grad[1]) == size(x)
9799
end
98100
for a in BINARY_ACTIVATIONS
99-
@test_throws ErrorException a(x, 0.1)
101+
@test size(a(x, 0.1)) == size(x)
102+
grad = Zygote.gradient(p -> sum(a(p, 0.1)), x)
103+
@test size(grad[1]) == size(x)
100104
end
101105
end
102106

0 commit comments

Comments
 (0)