1
1
# These are some macros (and supporting functions) to make it easier to define rules.
2
- using Base. Meta
3
2
4
- macro strip_linenos (expr)
5
- return esc (Base. remove_linenums! (expr))
6
- end
3
+ # ###########################################################################################
4
+ # ## @scalar_rule
7
5
8
6
"""
9
7
@scalar_rule(f(x₁, x₂, ...),
@@ -88,7 +86,6 @@ macro scalar_rule(call, maybe_setup, partials...)
88
86
frule_expr = scalar_frule_expr (__source__, f, call, setup_stmts, inputs, partials)
89
87
rrule_expr = scalar_rrule_expr (__source__, f, call, setup_stmts, inputs, partials)
90
88
91
- # ###########################################################################
92
89
# Final return: building the expression to insert in the place of this macro
93
90
code = quote
94
91
if ! ($ f isa Type) && fieldcount (typeof ($ f)) > 0
@@ -114,7 +111,6 @@ returns (in order) the correctly escaped:
114
111
- `partials`: which are all `Expr{:tuple,...}`
115
112
"""
116
113
function _normalize_scalarrules_macro_input (call, maybe_setup, partials)
117
- # ###########################################################################
118
114
# Setup: normalizing input form etc
119
115
120
116
if Meta. isexpr (maybe_setup, :macrocall ) && maybe_setup. args[1 ] == Symbol (" @setup" )
@@ -275,6 +271,9 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna
275
271
propagator_name (fname:: Symbol , propname:: Symbol ) = Symbol (fname, :_ , propname)
276
272
propagator_name (fname:: QuoteNode , propname:: Symbol ) = propagator_name (fname. value, propname)
277
273
274
+ # ###########################################################################################
275
+ # ## @non_differentiable
276
+
278
277
"""
279
278
@non_differentiable(signature_expression)
280
279
@@ -324,7 +323,7 @@ macro non_differentiable(sig_expr)
324
323
:($ (primal_name)($ (unconstrained_args... )))
325
324
else
326
325
normal_args = unconstrained_args[1 : end - 1 ]
327
- var_arg = unconstrained_args [end ]
326
+ var_arg = s [end ]
328
327
:($ (primal_name)($ (normal_args... ), $ (var_arg). .. ))
329
328
end
330
329
@@ -393,10 +392,13 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
393
392
end
394
393
end
395
394
396
-
397
- # ##########
395
+ # ###########################################################################################
398
396
# Helpers
399
397
398
+ macro strip_linenos (expr)
399
+ return esc (Base. remove_linenums! (expr))
400
+ end
401
+
400
402
"""
401
403
_isvararg(expr)
402
404
0 commit comments