Skip to content

Commit 5dd04df

Browse files
authored
Widen activation broadcast rules (#433)
* widen broadcast rules * add a test
1 parent 023cd3d commit 5dd04df

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/activations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ for (f, dfdx) in UNARY_ACTS
877877

878878
pullback = Symbol(:broadcasted_, f, :_pullback)
879879
@eval function rrule(::typeof(broadcasted),
880-
::typeof($f), x::Numeric)
880+
::typeof($f), x::Union{Numeric, Broadcast.Broadcasted})
881881
Ω = $f.(x)
882882
function $pullback(dΩ)
883883
x_thunk = InplaceableThunk(
@@ -908,7 +908,7 @@ for (f, dfdx1, dfdx2) in BINARY_ACTS
908908
pullback = Symbol(:broadcasted_, f, :_pullback_2arg)
909909
@eval function rrule(::typeof(broadcasted),
910910
::typeof($f),
911-
x1::Numeric, x2::Number)
911+
x1::Union{Numeric, Broadcast.Broadcasted}, x2::Number)
912912
Ω = $f.(x1, x2)
913913
## Allowing x2::Array would allow size(Ω) != size(x1), which is not handled here:
914914
$pullback(dΩ) = (NoTangent(), NoTangent(), @.(dΩ * $dfdx1), NO_ACT_GRAD)

test/activations.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,14 @@ has_rule(a) = rrule(a, 1f0) === nothing ? "(no rule)" : ""
319319
end
320320
end
321321

322+
using Base.Broadcast: broadcasted
323+
324+
@testset "lazy broadcasting" begin
325+
# ChainRules returns a Broadcasted, check these rules accept it
326+
@test rrule(broadcasted, relu, rrule(broadcasted, +, [1,2], 3)[1]) != nothing
327+
@test rrule(broadcasted, leakyrelu, rrule(broadcasted, +, [1,2], 3)[1], 0.2) != nothing
328+
end
329+
322330
@testset "Gradient correctness" begin
323331

324332
local rng = StableRNG(17)

0 commit comments

Comments
 (0)