Add chain rules for function calls without dims #83
| Codecov Report
Patch coverage:
Additional details and impacted files
@@            Coverage Diff             @@
##           master      #83      +/-   ##
==========================================
+ Coverage   84.13%   87.71%   +3.57%
==========================================
  Files           2        2
  Lines         208      236      +28
==========================================
+ Hits          175      207      +32
+ Misses         33       29       -4
 |
| Could we document that downstream packages have to implement the two-argument methods but not the ones without  Generally, the approach in the PR won't work anyway if a package has only implemented the one-argument version. | 
| I'm not following how your suggested approach could mean that the extra rules here aren't needed? One approach could be to replace this line (AbstractFFTs.jl/src/definitions.jl, line 62 in 7d698db) with
$f(x::AbstractArray) = $f(x::AbstractArray, 1:ndims(x))
Then, we wouldn't need the extra rule for no  | 
| I actually went in the opposite direction and generalized the chain rules to directly work with and without a  
 It makes the rules here a bit more complex, but now no assumptions whatsoever are made on what signatures downstream implementations support, so this is arguably the most robust solution. | 
IMO this PR is still suboptimal and a better design would be desirable. With the latest changes, the signatures of the rules now differ from the signatures of fft etc.
I think the cleanest solution is to only work with versions of fft etc. that implement dims, and to forward fft(x) etc. to the two-argument version. Otherwise we have to copy all rules and just remove dims everywhere. I think we should avoid such code duplication.
# we explicitly handle both unprovided and provided dims arguments in all rules, which
# results in some additional complexity here but means no assumptions are made on what
# signatures downstream implementations support.
function ChainRulesCore.frule(Δargs, ::typeof(fft), x::AbstractArray, dims=nothing)
I'm not happy about this PR because it means the signature of the AD rules is different from the signatures of fft etc. - we do not support dims = nothing in any of these methods.
A default positional argument simply expands to separate dispatches on the signatures fft(x, dims) and fft(x). The dims=nothing default is just a way of sharing logic between these cases.
So I would not say the signatures are different?
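A minimal sketch of the default-argument expansion described above. `toy_rule` is a hypothetical stand-in used only for illustration; it is not one of the `ChainRulesCore.frule` methods in this PR.

```julia
# toy_rule is a hypothetical stand-in, not part of AbstractFFTs or ChainRulesCore.
toy_rule(x::AbstractArray, dims=nothing) = dims === nothing ? :no_dims : :with_dims

# The single definition above creates two methods:
#   toy_rule(x::AbstractArray) and toy_rule(x::AbstractArray, dims)
n_methods = length(methods(toy_rule))

r1 = toy_rule([1.0, 2.0])      # dispatches to the one-argument method
r2 = toy_rule([1.0, 2.0], 1)   # dispatches to the two-argument method
```

Because `dims` is untyped in this sketch, the two-argument method accepts any value for it.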
My point is: You can call frule(.., fft, x, nothing) but you cannot call fft(x, nothing). This breaks the correspondence between the primal function and the rules, and makes the signatures inconsistent.
There is no clean way to share code as long as fft(x) and fft(x, dims) are completely separate. Introducing fft(x) = fft(x, 1:ndims(x)) or fft(x) = fft(x, nothing), and demanding that downstream packages implement fft(x, dims) only would solve these issues. Otherwise you have to copy the code or use something like @eval to do it for you.
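The forwarding idea could look roughly like the sketch below. `toy_primal` is a hypothetical function standing in for fft, and `sum` stands in for the actual transform; the real change would go into src/definitions.jl.

```julia
# Hypothetical toy primal standing in for fft; sum is used only as a placeholder transform.

# Generic package: defines the one-argument form once, by forwarding.
toy_primal(x::AbstractArray) = toy_primal(x, 1:ndims(x))

# "Downstream" backend: implements only the two-argument method.
toy_primal(x::Array, dims) = sum(x; dims=dims)

x = [1.0 2.0; 3.0 4.0]
full = toy_primal(x)        # forwarded to toy_primal(x, 1:2)
same = toy_primal(x, 1:2)   # explicit call over all dimensions
cols = toy_primal(x, 1)     # reduce along the first dimension only
```

With this layout, a rule written only for the two-argument signature would be the one reached after forwarding, assuming the AD system differentiates through the one-argument method.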
That's a fair point, I hadn't considered the nothing case. Sharing code would be easy enough with a shared helper function, e.g. replacing my current function with something like _fft_rrule and calling it in both cases, so that all the dispatches are correct. If you're opposed to that, I can look into how to modify src/definitions.jl to support your solution.
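A rough sketch of the shared-helper idea, under the assumption that the primal is linear so the pushforward applies the same map to the tangent. All names here (`toy_primal`, `toy_frule`, `_toy_frule`) are hypothetical and only mirror the shape of the proposal; restricting `dims` to `Integer` is a simplification for the example.

```julia
# Hypothetical toy primals: two completely separate methods, both linear in x.
toy_primal(x::AbstractArray) = 2 .* x
toy_primal(x::AbstractArray, dims::Integer) = cumsum(x; dims=dims)

# Shared helper: forwards any trailing arguments unchanged to the primal.
# For a linear map, the tangent is pushed forward by the same map.
_toy_frule(Δx, x, args...) = (toy_primal(x, args...), toy_primal(Δx, args...))

# Thin public methods whose signatures mirror the primal's signatures exactly.
toy_frule(Δx, x::AbstractArray) = _toy_frule(Δx, x)
toy_frule(Δx, x::AbstractArray, dims::Integer) = _toy_frule(Δx, x, dims)

y1, Δy1 = toy_frule([1.0, 1.0], [1.0, 2.0])
y2, Δy2 = toy_frule([1.0, 1.0], [1.0, 2.0], 1)

# Neither the rule nor the primal has a method accepting dims = nothing.
rule_takes_nothing = hasmethod(toy_frule, Tuple{Vector{Float64}, Vector{Float64}, Nothing})
primal_takes_nothing = hasmethod(toy_primal, Tuple{Vector{Float64}, Nothing})
```

The logic lives in one place, and the rule can only be called with argument lists that the primal also accepts.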
| See my response to your comment -- I don't really agree that the signatures are different, and even explicitly writing out separate rules for  Also, I see it as an inherent benefit to avoid modifying  | 
| So it's possible to avoid code copying and get the dispatches right if one makes a helper function e.g.  | 
Addresses issue with existing chain rules observed in FluxML/Zygote.jl#1386