1- # Heavily inspired from `JuMP/src/parse_expr.jl` code.
2-
3- export @rewrite
4-
51"""
62 @rewrite(expr)
73
@@ -28,6 +24,7 @@ macro rewrite(expr)
2824end
2925
3026struct Zero end
27+
3128# # We need to copy `x` as it will be used as might be given by the user and be
3229# # given as first argument of `operate!!`.
3330# Base.:(+)(zero::Zero, x) = copy_if_mutable(x)
@@ -36,15 +33,19 @@ struct Zero end
3633function operate (:: typeof (add_mul), :: Zero , args:: Vararg{Any,N} ) where {N}
3734 return operate (* , args... )
3835end
36+
3937function operate (:: typeof (sub_mul), :: Zero , x)
4038 # `operate(*, x)` would redirect to `copy_if_mutable(x)` which would be a
4139 # useless copy.
4240 return operate (- , x)
4341end
42+
4443function operate (:: typeof (sub_mul), :: Zero , x, y, args:: Vararg{Any,N} ) where {N}
4544 return operate (- , operate (* , x, y, args... ))
4645end
46+
4747broadcast!! (:: Union{typeof(add_mul),typeof(+)} , :: Zero , x) = copy_if_mutable (x)
48+
4849broadcast!! (:: typeof (add_mul), :: Zero , x, y) = x * y
4950
5051# Needed in `@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)`
@@ -60,6 +61,7 @@ Base.:-(z::Zero, ::Zero) = z
6061Base.:- (z:: Zero ) = z
6162Base.:+ (z:: Zero ) = z
6263Base.:* (z:: Zero ) = z
64+
6365function Base.:/ (z:: Zero , x:: Any )
6466 if iszero (x)
6567 throw (DivideError ())
7476_any_zero () = false
7577_any_zero (:: Any , args:: Vararg{Any,N} ) where {N} = _any_zero (args... )
7678_any_zero (:: Zero , :: Vararg{Any,N} ) where {N} = true
79+
7780function operate!! (
7881 op:: Union{typeof(add_mul),typeof(sub_mul)} ,
7982 x,
@@ -93,8 +96,6 @@ Base.length(::Zero) = 1
9396Base. iterate (z:: Zero ) = (z, nothing )
9497Base. iterate (:: Zero , :: Nothing ) = nothing
9598
96- using Base. Meta
97-
9899# See `JuMP._try_parse_idx_set`
99100function _try_parse_idx_set (arg:: Expr )
100101 # [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and
@@ -114,10 +115,10 @@ end
114115
115116function _parse_idx_set (arg:: Expr )
116117 parse_done, idxvar, idxset = _try_parse_idx_set (arg)
117- if parse_done
118- return idxvar, idxset
118+ if ! parse_done
119+ error ( " Invalid syntax: $arg " )
119120 end
120- return error ( " Invalid syntax: $arg " )
121+ return idxvar, idxset
121122end
122123
123124"""
@@ -143,8 +144,7 @@ function rewrite_generator(ex, inner)
143144 # `i + j for i in 1:2 for j in 1:2` is a `flatten` expression
144145 if isexpr (ex, :flatten )
145146 return rewrite_generator (ex. args[1 ], inner)
146- end
147- if ! isexpr (ex, :generator )
147+ elseif ! isexpr (ex, :generator )
148148 return inner (ex)
149149 end
150150 # `i + j for i in 1:2, j in 1:2` is a `generator` expression
@@ -204,19 +204,18 @@ function _parse_generator(
204204 @assert isexpr (inner_factor. args[2 ], :generator ) ||
205205 isexpr (inner_factor. args[2 ], :flatten )
206206 header = inner_factor. args[1 ]
207- if _is_sum (header)
208- _parse_generator_sum (
209- vectorized,
210- minus,
211- inner_factor. args[2 ],
212- current_sum,
213- left_factors,
214- right_factors,
215- new_var,
216- )
217- else
207+ if ! _is_sum (header)
218208 error (" Expected `sum` outside generator expression; got `$header `." )
219209 end
210+ return _parse_generator_sum (
211+ vectorized,
212+ minus,
213+ inner_factor. args[2 ],
214+ current_sum,
215+ left_factors,
216+ right_factors,
217+ new_var,
218+ )
220219end
221220
222221function _parse_generator_sum (
@@ -252,6 +251,7 @@ function _parse_generator_sum(
252251end
253252
254253_is_complex_expr (ex) = isa (ex, Expr) && ! isexpr (ex, :ref )
254+
255255function _is_decomposable_with_factors (ex)
256256 # `.+` and `.-` do not support being decomposed if `left_factors` or
257257 # `right_factors` are not empty. Otherwise, for instance
265265 rewrite(x)
266266
267267Rewrite the expression `x` as specified in [`@rewrite`](@ref).
268- Return a variable name as `Symbol` and the rewritten expression assigning the
268+ Returns a variable name as `Symbol` and the rewritten expression assigning the
269269value of the expression `x` to the variable.
270270"""
271271function rewrite (x)
@@ -302,9 +302,8 @@ function _is_comparison(ex::Expr)
302302 else
303303 return false
304304 end
305- else
306- return false
307305 end
306+ return false
308307end
309308
310309# `x[i = 1]` is a somewhat common user error. Catch it here.
@@ -315,6 +314,7 @@ function _has_assignment_in_ref(ex::Expr)
315314 return any (_has_assignment_in_ref, ex. args)
316315 end
317316end
317+
318318_has_assignment_in_ref (other) = false
319319
320320function _rewrite_sum (
@@ -347,14 +347,15 @@ function _rewrite_sum(
347347 return output, block
348348end
349349
350- function _start_summing (current_sum :: Nothing , first_term:: Function )
350+ function _start_summing (:: Nothing , first_term:: Function )
351351 variable = gensym ()
352352 return Expr (
353353 :block ,
354354 :($ variable = MutableArithmetics. Zero ()),
355355 first_term (variable),
356356 )
357357end
358+
358359function _start_summing (current_sum:: Symbol , first_term:: Function )
359360 return first_term (current_sum)
360361end
0 commit comments