Skip to content

Commit c51154b

Browse files
authored
Style improvements for rewrite.jl (#138)
1 parent eb5a3d7 commit c51154b

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

src/MutableArithmetics.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ module MutableArithmetics
99
import LinearAlgebra
1010
import SparseArrays
1111

12+
using Base.Meta # Used in rewrite.jl
13+
1214
# Performance note:
1315
# We use `Vararg` instead of splatting `...` as using `where N` forces Julia to
1416
# specialize in the number of arguments `N`. Otherwise, we get allocations and

src/rewrite.jl

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
# Heavily inspired from `JuMP/src/parse_expr.jl` code.
2-
3-
export @rewrite
4-
51
"""
62
@rewrite(expr)
73
@@ -28,6 +24,7 @@ macro rewrite(expr)
2824
end
2925

3026
struct Zero end
27+
3128
## We need to copy `x` as it will be used as might be given by the user and be
3229
## given as first argument of `operate!!`.
3330
#Base.:(+)(zero::Zero, x) = copy_if_mutable(x)
@@ -36,15 +33,19 @@ struct Zero end
3633
function operate(::typeof(add_mul), ::Zero, args::Vararg{Any,N}) where {N}
3734
return operate(*, args...)
3835
end
36+
3937
function operate(::typeof(sub_mul), ::Zero, x)
4038
# `operate(*, x)` would redirect to `copy_if_mutable(x)` which would be a
4139
# useless copy.
4240
return operate(-, x)
4341
end
42+
4443
function operate(::typeof(sub_mul), ::Zero, x, y, args::Vararg{Any,N}) where {N}
4544
return operate(-, operate(*, x, y, args...))
4645
end
46+
4747
broadcast!!(::Union{typeof(add_mul),typeof(+)}, ::Zero, x) = copy_if_mutable(x)
48+
4849
broadcast!!(::typeof(add_mul), ::Zero, x, y) = x * y
4950

5051
# Needed in `@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)`
@@ -60,6 +61,7 @@ Base.:-(z::Zero, ::Zero) = z
6061
Base.:-(z::Zero) = z
6162
Base.:+(z::Zero) = z
6263
Base.:*(z::Zero) = z
64+
6365
function Base.:/(z::Zero, x::Any)
6466
if iszero(x)
6567
throw(DivideError())
@@ -74,6 +76,7 @@ end
7476
_any_zero() = false
7577
_any_zero(::Any, args::Vararg{Any,N}) where {N} = _any_zero(args...)
7678
_any_zero(::Zero, ::Vararg{Any,N}) where {N} = true
79+
7780
function operate!!(
7881
op::Union{typeof(add_mul),typeof(sub_mul)},
7982
x,
@@ -93,8 +96,6 @@ Base.length(::Zero) = 1
9396
Base.iterate(z::Zero) = (z, nothing)
9497
Base.iterate(::Zero, ::Nothing) = nothing
9598

96-
using Base.Meta
97-
9899
# See `JuMP._try_parse_idx_set`
99100
function _try_parse_idx_set(arg::Expr)
100101
# [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and
@@ -114,10 +115,10 @@ end
114115

115116
function _parse_idx_set(arg::Expr)
116117
parse_done, idxvar, idxset = _try_parse_idx_set(arg)
117-
if parse_done
118-
return idxvar, idxset
118+
if !parse_done
119+
error("Invalid syntax: $arg")
119120
end
120-
return error("Invalid syntax: $arg")
121+
return idxvar, idxset
121122
end
122123

123124
"""
@@ -143,8 +144,7 @@ function rewrite_generator(ex, inner)
143144
# `i + j for i in 1:2 for j in 1:2` is a `flatten` expression
144145
if isexpr(ex, :flatten)
145146
return rewrite_generator(ex.args[1], inner)
146-
end
147-
if !isexpr(ex, :generator)
147+
elseif !isexpr(ex, :generator)
148148
return inner(ex)
149149
end
150150
# `i + j for i in 1:2, j in 1:2` is a `generator` expression
@@ -204,19 +204,18 @@ function _parse_generator(
204204
@assert isexpr(inner_factor.args[2], :generator) ||
205205
isexpr(inner_factor.args[2], :flatten)
206206
header = inner_factor.args[1]
207-
if _is_sum(header)
208-
_parse_generator_sum(
209-
vectorized,
210-
minus,
211-
inner_factor.args[2],
212-
current_sum,
213-
left_factors,
214-
right_factors,
215-
new_var,
216-
)
217-
else
207+
if !_is_sum(header)
218208
error("Expected `sum` outside generator expression; got `$header`.")
219209
end
210+
return _parse_generator_sum(
211+
vectorized,
212+
minus,
213+
inner_factor.args[2],
214+
current_sum,
215+
left_factors,
216+
right_factors,
217+
new_var,
218+
)
220219
end
221220

222221
function _parse_generator_sum(
@@ -252,6 +251,7 @@ function _parse_generator_sum(
252251
end
253252

254253
_is_complex_expr(ex) = isa(ex, Expr) && !isexpr(ex, :ref)
254+
255255
function _is_decomposable_with_factors(ex)
256256
# `.+` and `.-` do not support being decomposed if `left_factors` or
257257
# `right_factors` are not empty. Otherwise, for instance
@@ -265,7 +265,7 @@ end
265265
rewrite(x)
266266
267267
Rewrite the expression `x` as specified in [`@rewrite`](@ref).
268-
Return a variable name as `Symbol` and the rewritten expression assigning the
268+
Returns a variable name as `Symbol` and the rewritten expression assigning the
269269
value of the expression `x` to the variable.
270270
"""
271271
function rewrite(x)
@@ -302,9 +302,8 @@ function _is_comparison(ex::Expr)
302302
else
303303
return false
304304
end
305-
else
306-
return false
307305
end
306+
return false
308307
end
309308

310309
# `x[i = 1]` is a somewhat common user error. Catch it here.
@@ -315,6 +314,7 @@ function _has_assignment_in_ref(ex::Expr)
315314
return any(_has_assignment_in_ref, ex.args)
316315
end
317316
end
317+
318318
_has_assignment_in_ref(other) = false
319319

320320
function _rewrite_sum(
@@ -347,14 +347,15 @@ function _rewrite_sum(
347347
return output, block
348348
end
349349

350-
function _start_summing(current_sum::Nothing, first_term::Function)
350+
function _start_summing(::Nothing, first_term::Function)
351351
variable = gensym()
352352
return Expr(
353353
:block,
354354
:($variable = MutableArithmetics.Zero()),
355355
first_term(variable),
356356
)
357357
end
358+
358359
function _start_summing(current_sum::Symbol, first_term::Function)
359360
return first_term(current_sum)
360361
end

0 commit comments

Comments
 (0)