Skip to content

Commit 3751802

Browse files
authored
Handle broadcasted operators (#12)
1 parent 54e4b7d commit 3751802

File tree

2 files changed

+85
-17
lines changed

2 files changed

+85
-17
lines changed

src/macros.jl

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,6 @@ const op_checked = Dict(
7979
:- => :(checked_sub),
8080
:* => :(checked_mul),
8181
:^ => :(checked_pow),
82-
:+= => :(checked_add),
83-
:-= => :(checked_sub),
84-
:*= => :(checked_mul),
85-
:^= => :(checked_pow),
8682
:abs => :(checked_abs),
8783
)
8884

@@ -93,13 +89,27 @@ const op_unchecked = Dict(
9389
:- => :(unchecked_sub),
9490
:* => :(unchecked_mul),
9591
:^ => :(unchecked_pow),
96-
:+= => :(unchecked_add),
97-
:-= => :(unchecked_sub),
98-
:*= => :(unchecked_mul),
99-
:^= => :(unchecked_pow),
10092
:abs => :(unchecked_abs)
10193
)
10294

95+
const broadcast_op_map = Dict(
96+
:.+ => :+,
97+
:.- => :-,
98+
:.* => :*,
99+
:.^ => :^
100+
)
101+
102+
const assignment_op_map = Dict(
103+
:+= => :+,
104+
:-= => :-,
105+
:*= => :*,
106+
:^= => :^,
107+
:.+= => :.+,
108+
:.-= => :.-,
109+
:.*= => :.*,
110+
:.^= => :.^,
111+
)
112+
103113
# resolve ambiguity when `-` used as symbol
104114
unchecked_negsub(x) = unchecked_neg(x)
105115
unchecked_negsub(x, y) = unchecked_sub(x, y)
@@ -110,18 +120,34 @@ checked_negsub(x, y) = checked_sub(x, y)
110120
function replace_op!(expr::Expr, op_map::Dict)
111121
if isexpr(expr, :call)
112122
f, len = expr.args[1], length(expr.args)
113-
op = isexpr(f, :.) ? f.args[2].value : f # handle module-scoped functions
114-
if op === :+ && len == 2 # unary +
123+
op = isexpr(f, :.) ? f.args[2].value : f # handle module-scoped functions
124+
if op === :+ && len == 2 # unary +
115125
# no action required
116-
elseif op === :- && len == 2 # unary -
126+
elseif op === :- && len == 2 # unary -
117127
op = get(op_map, Symbol("unary-"), op)
118128
if isexpr(f, :.)
119129
f.args[2] = QuoteNode(op)
120130
expr.args[1] = f
121131
else
122132
expr.args[1] = op
123133
end
124-
else # arbitrary call
134+
elseif op keys(broadcast_op_map) # broadcast operators
135+
op = get(broadcast_op_map, op, op)
136+
if length(expr.args) == 2 # unary operator
137+
if op == :-
138+
expr.head = :.
139+
expr.args = [
140+
get(op_map, Symbol("unary-"), op),
141+
Expr(:tuple, expr.args[2])]
142+
end
143+
# no action required for .+
144+
else
145+
expr.head = :.
146+
expr.args = [
147+
get(op_map, op, op),
148+
Expr(:tuple, expr.args[2:end]...)]
149+
end
150+
else # arbitrary call
125151
op = get(op_map, op, op)
126152
if isexpr(f, :.)
127153
f.args[2] = QuoteNode(op)
@@ -134,7 +160,7 @@ function replace_op!(expr::Expr, op_map::Dict)
134160
a = expr.args[i]
135161
if isa(a, Expr)
136162
replace_op!(a, op_map)
137-
elseif isa(a, Symbol) # operator as symbol function argument, e.g. `fold(+, ...)`
163+
elseif isa(a, Symbol) # operator as symbol function argument, e.g. `fold(+, ...)`
138164
op = if a == :-
139165
get(op_map, Symbol("ambig-"), a)
140166
else
@@ -146,13 +172,16 @@ function replace_op!(expr::Expr, op_map::Dict)
146172
expr.args[i] = op
147173
end
148174
end
149-
elseif isexpr(expr, (:+=, :-=, :*=, :^=)) # in-place operator
175+
elseif isexpr(expr, keys(assignment_op_map)) # assignment operators
150176
target = expr.args[1]
151177
arg = expr.args[2]
152178
op = expr.head
153-
op = get(op_map, op, op)
154-
expr.head = :(=)
155-
expr.args[2] = Expr(:call, op, target, arg)
179+
op = get(assignment_op_map, op, op)
180+
expr.head = startswith(string(op), ".") ? :.= : :(=) # is there a better test?
181+
expr.args[2] = replace_op!(Expr(:call, op, target, arg), op_map)
182+
elseif isexpr(expr, :.) # broadcast function
183+
op = expr.args[1]
184+
expr.args[1] = get(op_map, op, op)
156185
elseif !isexpr(expr, :macrocall) || expr.args[1] (Symbol("@checked"), Symbol("@unchecked"))
157186
for a in expr.args
158187
if isa(a, Expr)
@@ -162,3 +191,8 @@ function replace_op!(expr::Expr, op_map::Dict)
162191
end
163192
return expr
164193
end
194+
195+
if VERSION < v"1.6"
196+
import Base.Meta: isexpr
197+
isexpr(expr, heads) = isexpr(expr, collect(heads))
198+
end

test/runtests.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,40 @@ using SaferIntegers
302302
end))
303303
end
304304

305+
@testset "Broadcasted operators replaced" begin
306+
aa = fill(typemax(Int), 2)
307+
bb = fill(2, 2)
308+
cc = fill(typemin(Int), 2)
309+
@unchecked(.+cc) == cc
310+
@unchecked(.-cc) == cc
311+
@checked(.+cc) == cc
312+
@test_throws OverflowError @checked(.-cc)
313+
@unchecked(aa .+ bb) == fill(typemin(Int) + 1, 2)
314+
@test_throws OverflowError @checked aa .+ bb
315+
@unchecked(cc .- bb) == fill(typemax(Int) - 1, 2)
316+
@test_throws OverflowError @checked cc .- bb
317+
@unchecked(aa .* bb) == fill(-2, 2)
318+
@test_throws OverflowError @checked aa .* bb
319+
@unchecked(aa .^ bb) == fill(1, 2)
320+
@test_throws OverflowError @checked aa .^ bb
321+
@unchecked(abs.(cc)) == cc
322+
@test_throws OverflowError @checked abs.(cc)
323+
end
324+
325+
@testset "Broadcasted assignment operators replaced" begin
326+
aa = fill(typemax(Int), 2)
327+
bb = fill(2, 2)
328+
cc = fill(typemin(Int), 2)
329+
@unchecked(copy(aa) .+= bb) == fill(typemin(Int) + 1, 2)
330+
@test_throws OverflowError @checked aa .+ bb
331+
@unchecked(copy(cc) .-= bb) == fill(typemax(Int) - 1, 2)
332+
@test_throws OverflowError @checked cc .- bb
333+
@unchecked(copy(aa) .* bb) == fill(-2, 2)
334+
@test_throws OverflowError @checked aa .* bb
335+
@unchecked(copy(aa) .^ bb) == fill(1, 2)
336+
@test_throws OverflowError @checked aa .^ bb
337+
end
338+
305339
@testset "Elementwise array methods are replaced, and others throw" begin
306340
aa = fill(typemax(Int), 2)
307341
bb = fill(2, 2)

0 commit comments

Comments
 (0)