Skip to content

Commit 356e5d9

Browse files
authored
Add saturating integer math (#7)
1 parent 3751802 commit 356e5d9

File tree

5 files changed

+288
-43
lines changed

5 files changed

+288
-43
lines changed

src/OverflowContexts.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ module OverflowContexts
22

33
include("macros.jl")
44
include("base_ext.jl")
5+
include("base_ext_sat.jl")
56
include("abstractarraymath_ext.jl")
67

7-
export @default_checked, @default_unchecked, @checked, @unchecked,
8+
export @default_checked, @default_unchecked, @default_saturating, @checked, @unchecked, @saturating,
9+
checked_neg, checked_add, checked_sub, checked_mul, checked_pow, checked_negsub, checked_abs,
810
unchecked_neg, unchecked_add, unchecked_sub, unchecked_mul, unchecked_negsub, unchecked_pow, unchecked_abs,
9-
checked_neg, checked_add, checked_sub, checked_mul, checked_pow, checked_negsub, checked_abs
11+
saturating_neg, saturating_add, saturating_sub, saturating_mul, saturating_pow, saturating_negsub, saturating_abs
1012

1113
end # module

src/base_ext.jl

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
using Base: promote, afoldl, @_inline_meta
1+
using Base: BitInteger, promote, afoldl, @_inline_meta
22
import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs
3+
using Base.Checked: mul_with_overflow
34

45
if VERSION v"1.11-alpha"
6+
import Base: power_by_squaring
57
import Base.Checked: checked_pow
68
else
7-
using Base: BitInteger, throw_domerr_powbysq, to_power_type
8-
using Base.Checked: mul_with_overflow, throw_overflowerr_binaryop
9+
using Base: throw_domerr_powbysq, to_power_type
10+
using Base.Checked: throw_overflowerr_binaryop
911
end
1012

1113
# The Base methods have unchecked semantics, so just pass through
@@ -22,13 +24,22 @@ checked_add(a, b, c, xs...) = @checked (@_inline_meta; afoldl(+, (+)((+)(a, b),
2224
checked_sub(a, b, c, xs...) = @checked (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...))
2325
checked_mul(a, b, c, xs...) = @checked (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...))
2426

27+
saturating_add(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(+, (+)((+)(a, b), c), xs...))
28+
saturating_sub(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...))
29+
saturating_mul(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...))
30+
2531

2632
# promote unmatched number types to same type
2733
checked_add(x::Number, y::Number) = checked_add(promote(x, y)...)
2834
checked_sub(x::Number, y::Number) = checked_sub(promote(x, y)...)
2935
checked_mul(x::Number, y::Number) = checked_mul(promote(x, y)...)
3036
checked_pow(x::Number, y::Number) = checked_pow(promote(x, y)...)
3137

38+
saturating_add(x::Number, y::Number) = saturating_add(promote(x, y)...)
39+
saturating_sub(x::Number, y::Number) = saturating_sub(promote(x, y)...)
40+
saturating_mul(x::Number, y::Number) = saturating_mul(promote(x, y)...)
41+
saturating_pow(x::Number, y::Number) = saturating_pow(promote(x, y)...)
42+
3243

3344
# fallback to `unchecked_` for `Number` types that don't have more specific `checked_` methods
3445
checked_neg(x::T) where T <: Number = unchecked_neg(x)
@@ -38,6 +49,13 @@ checked_mul(x::T, y::T) where T <: Number = unchecked_mul(x, y)
3849
checked_pow(x::T, y::T) where T <: Number = unchecked_pow(x, y)
3950
checked_abs(x::T) where T <: Number = unchecked_abs(x)
4051

52+
saturating_neg(x::T) where T <: Number = unchecked_neg(x)
53+
saturating_add(x::T, y::T) where T <: Number = unchecked_add(x, y)
54+
saturating_sub(x::T, y::T) where T <: Number = unchecked_sub(x, y)
55+
saturating_mul(x::T, y::T) where T <: Number = unchecked_mul(x, y)
56+
saturating_pow(x::T, y::T) where T <: Number = unchecked_pow(x, y)
57+
saturating_abs(x::T) where T <: Number = unchecked_abs(x)
58+
4159

4260
# fallback to `unchecked_` for non-`Number` types
4361
checked_neg(x) = unchecked_neg(x)
@@ -51,50 +69,38 @@ checked_abs(x) = unchecked_abs(x)
5169
if VERSION < v"1.11"
5270
# Base.Checked only gained checked powers in 1.11
5371

54-
function checked_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger}
55-
@_inline_meta
56-
z, b = pow_with_overflow(x, y)
57-
b && throw_overflowerr_binaryop(:^, x, y)
58-
z
59-
end
72+
checked_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} =
73+
power_by_squaring(x_, p; mul = checked_mul)
6074

61-
function pow_with_overflow(x_, p::Integer)
75+
# Base.@assume_effects :terminates_locally # present in Julia 1.11 code, but only supported from 1.8 on
76+
function power_by_squaring(x_, p::Integer; mul=*)
6277
x = to_power_type(x_)
6378
if p == 1
64-
return (copy(x), false)
79+
return copy(x)
6580
elseif p == 0
66-
return (one(x), false)
81+
return one(x)
6782
elseif p == 2
68-
return mul_with_overflow(x, x)
83+
return mul(x, x)
6984
elseif p < 0
70-
isone(x) && return (copy(x), false)
71-
isone(-x) && return (iseven(p) ? one(x) : copy(x), false)
85+
isone(x) && return copy(x)
86+
isone(-x) && return iseven(p) ? one(x) : copy(x)
7287
throw_domerr_powbysq(x, p)
7388
end
7489
t = trailing_zeros(p) + 1
7590
p >>= t
76-
b = false
7791
while (t -= 1) > 0
78-
x, b1 = mul_with_overflow(x, x)
79-
b |= b1
92+
x = mul(x, x)
8093
end
8194
y = x
8295
while p > 0
8396
t = trailing_zeros(p) + 1
8497
p >>= t
8598
while (t -= 1) >= 0
86-
x, b1 = mul_with_overflow(x, x)
87-
b |= b1
99+
x = mul(x, x)
88100
end
89-
y, b1 = mul_with_overflow(y, x)
90-
b |= b1
101+
y = mul(y, x)
91102
end
92-
return y, b
93-
end
94-
pow_with_overflow(x::Bool, p::Unsigned) = ((p==0) | x, false)
95-
function pow_with_overflow(x::Bool, p::Integer)
96-
p < 0 && !x && throw_domerr_powbysq(x, p)
97-
return (p==0) | x, false
103+
return y
98104
end
99105

100106
end

src/base_ext_sat.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import Base: BitInteger
2+
import Base.Checked: mul_with_overflow
3+
4+
if VERSION v"1.11-alpha"
5+
import Base: power_by_squaring
6+
end
7+
8+
# saturating implementations
9+
const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128}
10+
11+
saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x)
12+
13+
if VERSION v"1.5"
14+
using Base: llvmcall
15+
16+
# These intrinsics were added in LLVM 8, which was first supported with Julia 1.5
17+
@generated function saturating_add(x::T, y::T) where T <: BitInteger
18+
llvm_su = T <: Signed ? "s" : "u"
19+
llvm_t = "i" * string(8sizeof(T))
20+
llvm_intrinsic = "llvm.$(llvm_su)add.sat.$llvm_t"
21+
:(ccall($llvm_intrinsic, llvmcall, $T, ($T, $T), x, y))
22+
end
23+
24+
@generated function saturating_sub(x::T, y::T) where T <: BitInteger
25+
llvm_su = T <: Signed ? "s" : "u"
26+
llvm_t = "i" * string(8sizeof(T))
27+
llvm_intrinsic = "llvm.$(llvm_su)sub.sat.$llvm_t"
28+
:(ccall($llvm_intrinsic, llvmcall, $T, ($T, $T), x, y))
29+
end
30+
31+
else
32+
import Base.Checked: add_with_overflow, sub_with_overflow
33+
34+
function saturating_add(x::T, y::T) where T <: BitInteger
35+
result, overflow_flag = add_with_overflow(x, y)
36+
if overflow_flag
37+
return sign(x) > 0 ?
38+
typemax(T) :
39+
typemin(T)
40+
end
41+
return result
42+
end
43+
44+
function saturating_sub(x::T, y::T) where T <: BitInteger
45+
result, overflow_flag = sub_with_overflow(x, y)
46+
if overflow_flag
47+
return y > x ?
48+
typemin(T) :
49+
typemax(T)
50+
end
51+
return result
52+
end
53+
end
54+
55+
function saturating_mul(x::T, y::T) where T <: BitInteger
56+
result, overflow_flag = mul_with_overflow(x, y)
57+
return overflow_flag ?
58+
(sign(x) == sign(y) ?
59+
typemax(T) :
60+
typemin(T)) :
61+
result
62+
end
63+
64+
saturating_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} =
65+
power_by_squaring(x_, p; mul = saturating_mul)
66+
67+
function saturating_abs(x::T) where T <: SignedBitInteger
68+
result = flipsign(x, x)
69+
return result < 0 ? typemax(T) : result
70+
end

src/macros.jl

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,41 @@ macro default_unchecked()
5050
end
5151
end
5252

53+
"""
54+
@default_saturating
55+
56+
Redirect default integer math to saturating operators for the current module. Only works at top-level.
57+
"""
58+
macro default_saturating()
59+
quote
60+
if !isdefined(@__MODULE__, :__OverflowContextDefaultSet)
61+
any(Base.isbindingresolved.(Ref(@__MODULE__), op_method_symbols)) &&
62+
error("A default context may only be set before any reference to the affected methods (+, -, *, ^, abs) in the target module.")
63+
else
64+
@warn "A previous default was set for this module. Previously defined methods in this module will be recompiled with this new default."
65+
end
66+
(@__MODULE__).eval(:(-(x) = OverflowContexts.saturating_neg(x)))
67+
(@__MODULE__).eval(:(+(x...) = OverflowContexts.saturating_add(x...)))
68+
(@__MODULE__).eval(:(-(x...) = OverflowContexts.saturating_sub(x...)))
69+
(@__MODULE__).eval(:(*(x...) = OverflowContexts.saturating_mul(x...)))
70+
(@__MODULE__).eval(:(^(x...) = OverflowContexts.saturating_pow(x...)))
71+
(@__MODULE__).eval(:(abs(x) = OverflowContexts.saturating_abs(x)))
72+
(@__MODULE__).eval(:(__OverflowContextDefaultSet = true))
73+
nothing
74+
end
75+
end
76+
77+
"""
78+
@checked expr
79+
80+
Perform all integer operations in `expr` using overflow-checked arithmetic.
81+
"""
82+
macro checked(expr)
83+
isa(expr, Expr) || return expr
84+
expr = copy(expr)
85+
return esc(replace_op!(expr, op_checked))
86+
end
87+
5388
"""
5489
@unchecked expr
5590
@@ -62,14 +97,14 @@ macro unchecked(expr)
6297
end
6398

6499
"""
65-
@checked expr
100+
@saturating expr
66101
67-
Perform all integer operations in `expr` using overflow-checked arithmetic.
102+
Perform all integer operations in `expr` using saturating arithmetic.
68103
"""
69-
macro checked(expr)
104+
macro saturating(expr)
70105
isa(expr, Expr) || return expr
71106
expr = copy(expr)
72-
return esc(replace_op!(expr, op_checked))
107+
return esc(replace_op!(expr, op_saturating))
73108
end
74109

75110
const op_checked = Dict(
@@ -92,6 +127,16 @@ const op_unchecked = Dict(
92127
:abs => :(unchecked_abs)
93128
)
94129

130+
const op_saturating = Dict(
131+
Symbol("unary-") => :(saturating_neg),
132+
Symbol("ambig-") => :(saturating_negsub),
133+
:+ => :(saturating_add),
134+
:- => :(saturating_sub),
135+
:* => :(saturating_mul),
136+
:^ => :(saturating_pow),
137+
:abs => :(saturating_abs)
138+
)
139+
95140
const broadcast_op_map = Dict(
96141
:.+ => :+,
97142
:.- => :-,
@@ -115,6 +160,8 @@ unchecked_negsub(x) = unchecked_neg(x)
115160
unchecked_negsub(x, y) = unchecked_sub(x, y)
116161
checked_negsub(x) = checked_neg(x)
117162
checked_negsub(x, y) = checked_sub(x, y)
163+
saturating_negsub(x) = saturating_neg(x)
164+
saturating_negsub(x, y) = saturating_sub(x, y)
118165

119166
# copied from CheckedArithmetic.jl and modified it
120167
function replace_op!(expr::Expr, op_map::Dict)
@@ -182,7 +229,7 @@ function replace_op!(expr::Expr, op_map::Dict)
182229
elseif isexpr(expr, :.) # broadcast function
183230
op = expr.args[1]
184231
expr.args[1] = get(op_map, op, op)
185-
elseif !isexpr(expr, :macrocall) || expr.args[1] (Symbol("@checked"), Symbol("@unchecked"))
232+
elseif !isexpr(expr, :macrocall) || expr.args[1] (Symbol("@checked"), Symbol("@unchecked"), Symbol("@saturating"))
186233
for a in expr.args
187234
if isa(a, Expr)
188235
replace_op!(a, op_map)

0 commit comments

Comments
 (0)