- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 615
Some fast paths + type fixes #2137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cb43150
              7bf7594
              8e9a5cc
              fcbc7b4
              ce1cf88
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -210,7 +210,7 @@ true | |
| ``` | ||
| """ | ||
| struct LayerNorm{F,D,T,N} | ||
| λ::F | ||
| λ::F # this field is not used | ||
| diag::D | ||
| ϵ::T | ||
| size::NTuple{N,Int} | ||
|  | @@ -254,16 +254,16 @@ function _norm_layer_forward( | |
| end | ||
| end | ||
|  | ||
| o = _norm_layer_forward(x, μ, σ², l.ϵ) | ||
| hasaffine(l) || return l.λ.(o) | ||
|  | ||
| γ = reshape(l.γ, affine_shape) | ||
| β = reshape(l.β, affine_shape) | ||
| return l.λ.(γ .* o .+ β) | ||
| s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, fewer inv∘sqrt calls | ||
| if hasaffine(l) | ||
| γ = reshape(l.γ, affine_shape) # ideally reshape on construction? | ||
| β = reshape(l.β, affine_shape) | ||
| return l.λ.(γ .* s .* (x .- μ) .+ β) | ||
| else | ||
| return l.λ.(s .* (x .- μ)) | ||
| end | ||
| end | ||
|  | ||
| @inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ) | ||
|  | ||
| function _track_stats!( | ||
| bn, x::AbstractArray{T, N}, μ, σ², reduce_dims, | ||
| ) where {T, N} | ||
|  | @@ -356,10 +356,9 @@ end | |
| @functor BatchNorm | ||
| trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;) | ||
|  | ||
| function (BN::BatchNorm)(x) | ||
| @assert size(x, ndims(x)-1) == BN.chs | ||
| N = ndims(x) | ||
| reduce_dims = [1:N-2; N] | ||
| function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N} | ||
| size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))") | ||
| reduce_dims = ntuple(d -> d + (d==N-1), N-1) # i.e. 1:N with N-1 removed | ||
| 
      Comment on lines
    
      +360
     to 
      +361
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might as well take the opportunity to mark these lines and the definition of  | ||
| affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) | ||
| return _norm_layer_forward(BN, x; reduce_dims, affine_shape) | ||
| end | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -166,11 +166,11 @@ end | |
| end | ||
|  | ||
| # with activation function | ||
| let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0; | ||
| 2.0 4.0 6.0] | ||
| let m = BatchNorm(2, sigmoid) | ||
| x = Float32[1.0 3.0 5.0; 2.0 4.0 6.0] | ||
| y = m(x) | ||
| @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) | ||
| @inferred m(x) | ||
| @inferred m(x) # fails when x::Matrix{Float64}, do we care? | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you know why this fails? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not. Checking branches  | ||
| end | ||
|  | ||
| let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change hits the following failure:
So it's differentiating this:
https://github.com/JuliaDiff/ChainRules.jl/blob/9a405f732758552cd945a110adb6828a997887a8/src/rulesets/Statistics/statistics.jl#L7
and differentiating the rule for
unique, which doesn't handle this case.Zygote differentiates so many things it need not touch, surely this adds startup time... you only notice when it fails.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One of its fatal flaws, you might say. Usually first-order differentiation is well-behaved because control flow and possible mutation are hidden away, but all bets are off with second order...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even first order, I think it does a lot which it need not do. Just most of the resulting errors have already been found. Same thing in trying out Diffractor -- lots of errors from obscure code calculating indices for views or whatever, to a human obviously non-diff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one fixed in JuliaDiff/ChainRules.jl#687
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's one definite benefit of tracing/overload-based ADs. Anything not numerically interesting gets ignored or falls away in the final tape/graph.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup. I presume that any kind of activity tracking would also let you eliminate most off-track things. Maybe declaring integers (and all structs not containing floats) non-diff would also help.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's certainly a lot we could learn from projects like differentiable Swift (which uses activity analysis). It seems unlikely Zygote will be where such knowledge is applied given how poorly integrated it is with the compiler.