- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 129
          Add bias_act!
          #457
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
        
      
    
  
     Merged
                    
  
    Add bias_act!
  
  #457
                      Changes from all commits
      Commits
    
    
            Show all changes
          
          
            11 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      554f339
              
                sometimes-in-place bias_act
              
              
                mcabbott 13ab2e7
              
                update after dropout PR
              
              
                mcabbott f08fc0b
              
                add to docs
              
              
                mcabbott a9136e7
              
                also fix two unrelated docstring which just told you what the functio…
              
              
                mcabbott 83882de
              
                tidy & un-comment
              
              
                mcabbott 419725c
              
                comment out 2nd path again
              
              
                mcabbott dbf39d4
              
                add Returns for 1.6
              
              
                mcabbott 791531a
              
                upgrade tests
              
              
                mcabbott 7b04b15
              
                more tests
              
              
                mcabbott c9a5722
              
                skip hardσ tests
              
              
                mcabbott cd99b77
              
                Update test/bias_act.jl
              
              
                mcabbott File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
|  | ||
| using NNlib: fast_act, tanh_fast | ||
| using ChainRulesCore | ||
|  | ||
| const RCR = RuleConfig{>:HasReverseMode} | ||
|  | ||
| # This just saves typing `only.(only.(` many times: | ||
| @inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x))) | ||
|  | ||
| # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` | ||
| # is independent of `x`, as `_return_type` says `Union{}` when calling is an error. | ||
| struct NotaNumber <: Real end | ||
|  | ||
| """ | ||
| bias_act!(σ, x, b) | ||
| This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh` | ||
| with `sigmoid_fast` & `tanh_fast`. | ||
| It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`. | ||
| When used within a gradient, it will overwrite only when `σ` has | ||
| a method of `derivatives_given_output` which does not need the input at all. | ||
| Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0` where the derivative | ||
| contains only `Ω` (the output) not `x`. | ||
| !!! warning | ||
| This is not safe to use if `x` is still needed for the gradient | ||
| of some other function. Incorrect use will give silently wrong answers. | ||
| It is intended mainly for Flux layers, in which the previous operation is | ||
| known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer. | ||
| """ | ||
| bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) = | ||
| _fast_broadcast!(fast_act(σ, x)∘(+), x, b) # works around a SIMD bug | ||
|  | ||
| function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) | ||
| b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") | ||
| _fast_broadcast!(fast_act(σ, x), x) | ||
| end | ||
|  | ||
| function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) | ||
| b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") | ||
| x # pass-through | ||
| end | ||
|  | ||
| function bias_act!(σ::Function, x::AbstractArray, b) | ||
| b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") | ||
| fast_act(σ, x).(x .+ b) # fallback | ||
| end | ||
|  | ||
| function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B} | ||
| biasgrad = if eltype(B) !== Bool | ||
| # Summing over ndims(x)+1 is a trick to make b_dims type-stable | ||
| dims = ntuple(d -> size(b,d)==1 ? d : N+1, N) | ||
| _biasgrad(dx) = reshape(sum(dx; dims), size(b)) | ||
| else | ||
| Returns(NoTangent()) | ||
| end | ||
|  | ||
| # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ | ||
| if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) | ||
| Ω = bias_act!(σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat} | ||
| function bias_act!_fastback(Δ) | ||
| # Tempting to overwrite x again, but only safe if you call pullback at most once, | ||
| # TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340 | ||
| # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592 | ||
| dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ) | ||
| return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
| end | ||
| return Ω, bias_act!_fastback | ||
|  | ||
| # # Slower path: can't overwrite x, but can use derivatives_given_output | ||
| # # This case is WRONG and tests fail, but not sure why | ||
| # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) | ||
| # Ω2 = fast_act(σ, x).(x) .+ b | ||
| # @show σ b | ||
| # function bias_act!_back2(Δ) | ||
| # dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ) | ||
| # return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
| # end | ||
| # return Ω2, bias_act!_back2 | ||
|  | ||
| # Fallback path: let AD handle the broadcast | ||
| else | ||
| Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b)) | ||
| @inline function bias_act!_slowback(Δ) | ||
| _, _, dx = back(Δ) | ||
| return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
| end | ||
| return Ω3, bias_act!_slowback | ||
| end | ||
| end | ||
|  | ||
| # Two easy cases with identity | ||
| function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B} | ||
| dims = ntuple(d -> size(b,d)==1 ? d : N+1, N) | ||
| biasgrad(dx) = reshape(sum(dx; dims), size(b)) | ||
| function bias_act!_idback(Δ) | ||
| dx = unthunk(Δ) | ||
| return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
| end | ||
| return bias_act!(identity, x, b), bias_act!_idback | ||
| end | ||
| function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N} | ||
| bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent()) | ||
| return x, bias_act!_trivial | ||
| end | ||
|  | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| using NNlib, Zygote, ChainRulesCore, Test | ||
| using Zygote: ForwardDiff | ||
|  | ||
| ACTIVATION_FUNCTIONS = | ||
| [@eval($a) for a in NNlib.ACTIVATIONS] | ||
|  | ||
| @testset "bias_act!" begin | ||
| x = randn(3,4) | ||
| b = randn(3) | ||
| @test @inferred(bias_act!(identity, x, false)) === x # pass-through | ||
| @test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b) | ||
| @test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b) | ||
| @test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b) | ||
| @test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x) | ||
|  | ||
| # Check that it does overwrite: | ||
| x32 = rand(Float32, 3, 4); x32copy = copy(x32) | ||
| @test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b) | ||
| @test x32 ≈ cbrt.(x32copy .+ b) | ||
|  | ||
| x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias | ||
| @test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy) | ||
| @test x32 ≈ tanh.(x32copy) | ||
|  | ||
| x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule | ||
| y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b) | ||
| @test y ≈ x32 ≈ relu.(x32copy .+ b) | ||
|  | ||
| x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias | ||
| y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false) | ||
| @test y ≈ x32 ≈ relu.(x32copy) | ||
|  | ||
| # Check that it doesn't try to overwrite non-float arrays: | ||
| xint = rand(-3:3, 3, 4) | ||
| bint = rand(-2:2, 3) | ||
| @test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint | ||
| @test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint) | ||
| @test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint) | ||
|  | ||
| # Reject bias===true so that Bool means one thing: | ||
| @test_throws Exception bias_act!(identity, rand(3), true) | ||
| @test_throws Exception bias_act!(cbrt, rand(3), true) | ||
| @test_throws Exception bias_act!(cbrt, rand(1:3, 3), true) | ||
|  | ||
| @testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt], | ||
| ACTIVATION_FUNCTIONS, | ||
| [x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)]) | ||
| # Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about. | ||
| fun == rrelu && continue # this one is randomised! | ||
| fun == hardσ && continue # this one has heisenbugs, not solved by discontinuity-avoidance code below | ||
|  | ||
| @test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b) | ||
| @test bias_act!(fun, copy(x), false) ≈ fun.(x) | ||
|  | ||
| gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x) | ||
| gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps()) | ||
| gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps()) | ||
| if !(gx ≈ gxplus ≈ gxminus) | ||
| @warn "skipping gradient tests due to discontinuity" fun x b | ||
| continue | ||
| end | ||
| @test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1] | ||
|  | ||
| gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x) | ||
| gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps()) | ||
| gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps()) | ||
| if !(gx2 ≈ gx2plus ≈ gx2minus) | ||
| @warn "skipping gradient tests due to discontinuity" fun x | ||
| continue | ||
| end | ||
| @test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1] | ||
|  | ||
| gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b) | ||
| @test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1] | ||
|  | ||
| @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,) | ||
| @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,) | ||
| end | ||
|  | ||
| @testset "gradient for fast_broadcast!" begin | ||
| # Gradient definition is just to disable mutation inside 2nd order AD | ||
| gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x) | ||
| @test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1] | ||
|  | ||
| # relu should take the fast path | ||
| g2 = ForwardDiff.gradient(x) do x | ||
| sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) | ||
| end | ||
| @test_skip gx ≈ Zygote.gradient(x) do x # Here global variable b causes an error | ||
| sum(abs2,Zygote. gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) | ||
| end | ||
| # Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)). | ||
| # [5] (::typeof(∂(accum_global)))(Δ::Nothing) | ||
| @test g2 ≈ Zygote.gradient(x, b) do x, b | ||
| sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1]) | ||
| end[1] | ||
|  | ||
| g3 = ForwardDiff.gradient(x) do x | ||
| sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1]) | ||
| end | ||
| @test g3 ≈ Zygote.gradient(x, b) do x, b | ||
| sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1]) | ||
| end[1] | ||
|  | ||
| # Anon function sure to take the generic path | ||
| g4 = ForwardDiff.gradient(x) do x | ||
| sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1]) | ||
| end | ||
| @test g4 ≈ Zygote.gradient(x, b) do x, b | ||
| sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1]) | ||
| end[1] | ||
| end | ||
| end | ||
|  | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This slightly elaborate thing is avoiding my best guess as to why there were failures on CI: hardsigmoid has discontinuities, and if
xhits them, the two gradients may not agree.But it doesn't seem to work: