Skip to content

Commit 23ec91d

Browse files
committed
minor cleanup of rule_definition_tools.jl
1 parent 460a559 commit 23ec91d

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
3+
using Base.Meta
34
using LinearAlgebra
45
using SparseArrays: SparseVector, SparseMatrixCSC
56
using Compat: hasfield

src/rule_definition_tools.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# These are some macros (and supporting functions) to make it easier to define rules.
2-
using Base.Meta
32

4-
macro strip_linenos(expr)
5-
return esc(Base.remove_linenums!(expr))
6-
end
3+
############################################################################################
4+
### @scalar_rule
75

86
"""
97
@scalar_rule(f(x₁, x₂, ...),
@@ -88,7 +86,6 @@ macro scalar_rule(call, maybe_setup, partials...)
8886
frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
8987
rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
9088

91-
############################################################################
9289
# Final return: building the expression to insert in the place of this macro
9390
code = quote
9491
if !($f isa Type) && fieldcount(typeof($f)) > 0
@@ -114,7 +111,6 @@ returns (in order) the correctly escaped:
114111
- `partials`: which are all `Expr{:tuple,...}`
115112
"""
116113
function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
117-
############################################################################
118114
# Setup: normalizing input form etc
119115

120116
if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup")
@@ -275,6 +271,9 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna
275271
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
276272
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)
277273

274+
############################################################################################
275+
### @non_differentiable
276+
278277
"""
279278
@non_differentiable(signature_expression)
280279
@@ -324,7 +323,7 @@ macro non_differentiable(sig_expr)
324323
:($(primal_name)($(unconstrained_args...)))
325324
else
326325
normal_args = unconstrained_args[1:end-1]
327-
var_arg = unconstrained_args[end]
326+
var_arg = s[end]
328327
:($(primal_name)($(normal_args...), $(var_arg)...))
329328
end
330329

@@ -393,10 +392,13 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
393392
end
394393
end
395394

396-
397-
###########
395+
############################################################################################
398396
# Helpers
399397

398+
macro strip_linenos(expr)
399+
return esc(Base.remove_linenums!(expr))
400+
end
401+
400402
"""
401403
_isvararg(expr)
402404

0 commit comments

Comments
 (0)