Skip to content

Commit 54e4b7d

Browse files
authored
Handle elementwise array math + change imports to usings (#13)
1 parent 5a9023a commit 54e4b7d

File tree

5 files changed

+62
-4
lines changed

5 files changed

+62
-4
lines changed

src/OverflowContexts.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module OverflowContexts
22

33
include("macros.jl")
44
include("base_ext.jl")
5+
include("abstractarraymath_ext.jl")
56

67
export @default_checked, @default_unchecked, @checked, @unchecked,
78
unchecked_neg, unchecked_add, unchecked_sub, unchecked_mul, unchecked_negsub, unchecked_pow, unchecked_abs,

src/abstractarraymath_ext.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul
2+
if VERSION v"1.11-alpha"
3+
import Base.Checked: checked_pow
4+
end
5+
if VERSION v"1.2"
6+
using Base: broadcast_preserving_zero_d
7+
end
8+
9+
checked_neg(A::AbstractArray) = broadcast_preserving_zero_d(checked_neg, A)
10+
for f in (:checked_add, :checked_sub)
11+
@eval function ($f)(A::AbstractArray, B::AbstractArray)
12+
promote_shape(A, B) # check size compatibility
13+
broadcast_preserving_zero_d($f, A, B)
14+
end
15+
end
16+
checked_mul(A::Number, B::AbstractArray) = broadcast_preserving_zero_d(checked_mul, B, A)
17+
checked_mul(A::AbstractArray, B::Number) = broadcast_preserving_zero_d(checked_mul, A, B)
18+
checked_mul(A::AbstractArray, B::AbstractArray) = error("Checked matrix multiplication is not available")
19+
20+
checked_pow(A::AbstractArray, B::Number) = error("Checked matrix multiplication is not available")
21+
22+
# Compatibility with Julia 1.0 and 1.1
23+
if VERSION < v"1.2"
24+
if VERSION < v"1.1"
25+
@inline materialize(bc::Base.Broadcast.Broadcasted) = copy(Base.Broadcast.instantiate(bc))
26+
else
27+
using Base.Broadcast: materialize
28+
end
29+
@inline function broadcast_preserving_zero_d(f, As...)
30+
bc = Base.Broadcast.broadcasted(f, As...)
31+
r = materialize(bc)
32+
return length(axes(bc)) == 0 ? fill!(similar(bc, typeof(r)), r) : r
33+
end
34+
end

src/base_ext.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import Base: promote, afoldl, @_inline_meta
1+
using Base: promote, afoldl, @_inline_meta
22
import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs
33

44
if VERSION v"1.11-alpha"
55
import Base.Checked: checked_pow
66
else
7-
import Base: BitInteger, throw_domerr_powbysq, to_power_type
8-
import Base.Checked: mul_with_overflow, throw_overflowerr_binaryop
7+
using Base: BitInteger, throw_domerr_powbysq, to_power_type
8+
using Base.Checked: mul_with_overflow, throw_overflowerr_binaryop
99
end
1010

1111
# The Base methods have unchecked semantics, so just pass through

src/macros.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Base.Meta: isexpr
1+
using Base.Meta: isexpr
22

33
const op_method_symbols = (:+, :-, :*, :^, :abs)
44

test/runtests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,26 @@ using SaferIntegers
301301
@test_throws OverflowError typemax(SafeInt) + 1
302302
end))
303303
end
304+
305+
@testset "Elementwise array methods are replaced, and others throw" begin
306+
aa = fill(typemax(Int), 2)
307+
bb = fill(2, 2)
308+
cc = fill(typemin(Int), 2)
309+
dd = fill(typemax(Int), 2, 2)
310+
@unchecked(+cc) == cc
311+
@unchecked(-cc) == cc
312+
@checked(+cc) == cc
313+
@test_throws OverflowError @checked(-cc)
314+
@unchecked(aa + bb) == fill(typemin(Int) + 1, 2)
315+
@test_throws OverflowError @checked aa + bb
316+
@unchecked(cc - bb) == fill(typemax(Int) - 1, 2)
317+
@test_throws OverflowError @checked cc - bb
318+
@unchecked(2aa) == fill(-2, 2)
319+
@test_throws OverflowError @checked 2aa
320+
@unchecked(aa * 2) == fill(-2, 2)
321+
@test_throws OverflowError @checked aa * 2
322+
@unchecked(aa * bb') == fill(-2, 2, 2)
323+
@test_throws ErrorException @checked aa * bb'
324+
@unchecked(dd ^ 2) == fill(2, 2, 2)
325+
@test_throws ErrorException @checked dd ^ 2
326+
end

0 commit comments

Comments
 (0)