Skip to content

Commit 7aa1d39

Browse files
authored
Add literal_pow transformation (#19)
1 parent 117a93d commit 7aa1d39

File tree

6 files changed

+61
-20
lines changed

6 files changed

+61
-20
lines changed

src/OverflowContexts.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@ module OverflowContexts
33
const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128}
44
const UnsignedBitInteger = Union{UInt8, UInt16, UInt32, UInt64, UInt128}
55

6+
using Base: BitInteger, promote, afoldl, @_inline_meta
7+
import Base: literal_pow
8+
import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs,
9+
checked_div, checked_fld, checked_cld, checked_mod, checked_rem
10+
using Base.Checked: mul_with_overflow
11+
12+
if VERSION v"1.11-alpha"
13+
import Base: power_by_squaring
14+
import Base.Checked: checked_pow
15+
else
16+
using Base: throw_domerr_powbysq, to_power_type
17+
using Base.Checked: throw_overflowerr_binaryop
18+
end
19+
620
include("macros.jl")
721
include("checked.jl")
822
include("unchecked.jl")

src/checked.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
using Base: BitInteger, promote, afoldl, @_inline_meta
2-
import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs,
3-
checked_div, checked_fld, checked_cld, checked_mod, checked_rem
4-
using Base.Checked: mul_with_overflow
5-
6-
if VERSION v"1.11-alpha"
7-
import Base: power_by_squaring
8-
import Base.Checked: checked_pow
9-
else
10-
using Base: throw_domerr_powbysq, to_power_type
11-
using Base.Checked: throw_overflowerr_binaryop
12-
end
13-
141
# resolve ambiguity when `-` used as symbol
152
checked_negsub(x) = checked_neg(x)
163
checked_negsub(x, y) = checked_sub(x, y)
@@ -87,3 +74,19 @@ if VERSION < v"1.11"
8774
return y
8875
end
8976
end
77+
78+
# adapted from Base intfuncs.jl; negative literal powers promote to floating point
79+
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{0}) = one(x)
80+
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{1}) = x
81+
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{2}) = @checked x * x
82+
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{3}) = @checked x * x * x
83+
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{-1}) = literal_pow(^, x, Val(-1))
84+
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{-2}) = literal_pow(^, x, Val(-2))
85+
86+
@inline function literal_pow(f::typeof(checked_pow), x, ::Val{p}) where {p}
87+
if p < 0
88+
literal_pow(^, x, Val(p))
89+
else
90+
f(x, p)
91+
end
92+
end

src/macros.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,18 @@ function replace_op!(expr::Expr, op_map::Dict)
235235
Expr(:tuple, expr.args[2:end]...)]
236236
end
237237
else # arbitrary call
238+
op_orig = op
238239
op = get(op_map, op, op)
239240
if isexpr(f, :.)
240241
f.args[2] = QuoteNode(op)
241242
expr.args[1] = f
242243
else
243244
expr.args[1] = op
245+
if op_orig == :^ && expr.args[3] isa Integer
246+
# literal_pow transformation
247+
pushfirst!(expr.args, :(Base.literal_pow))
248+
expr.args[4] = :(Val($(expr.args[4])))
249+
end
244250
end
245251
end
246252
for i in 2:length(expr.args)

src/saturating.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
import Base: BitInteger
2-
import Base.Checked: mul_with_overflow
3-
4-
if VERSION v"1.11-alpha"
5-
using Base: power_by_squaring
6-
end
7-
81
# resolve ambiguity when `-` used as symbol
92
saturating_negsub(x) = saturating_neg(x)
103
saturating_negsub(x, y) = saturating_sub(x, y)
@@ -156,3 +149,19 @@ function saturating_mod(x::T, y::T) where T <: SignedBitInteger
156149
end
157150

158151
saturating_mod(x::T, y::T) where T <: UnsignedBitInteger = @saturating rem(x, y)
152+
153+
# adapted from Base intfuncs.jl; negative literal powers promote to floating point
154+
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{0}) = one(x)
155+
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{1}) = x
156+
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{2}) = @saturating x * x
157+
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{3}) = @saturating x * x * x
158+
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{-1}) = literal_pow(^, x, Val(-1))
159+
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{-2}) = literal_pow(^, x, Val(-2))
160+
161+
@inline function literal_pow(f::typeof(saturating_pow), x, ::Val{p}) where {p}
162+
if p < 0
163+
literal_pow(^, x, Val(p))
164+
else
165+
f(x, p)
166+
end
167+
end

src/unchecked.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,6 @@ unchecked_rem(x::T, y::T) where T <: UnsignedBitInteger =
6767

6868
unchecked_mod(x::T, y::T) where T <: SignedBitInteger = x - unchecked_fld(x, y) * y
6969
unchecked_mod(x::T, y::T) where T <: UnsignedBitInteger = unchecked_rem(x, y)
70+
71+
# adapted from Base intfuncs.jl; negative literal powers promote to floating point
72+
@inline literal_pow(::typeof(unchecked_pow), x, ::Val{p}) where {p} = literal_pow(^, x, Val(p))

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,3 +792,9 @@ end
792792
@test_throws ErrorException @saturating aa * bb'
793793
@test_throws ErrorException @saturating dd ^ 2
794794
end
795+
796+
@testset "literal_pow transformation" begin
797+
expr = @macroexpand @checked 5 ^ 2
798+
@test expr.args[1] == :(Base.literal_pow)
799+
@test expr.args[2] == :(OverflowContexts.checked_pow)
800+
end

0 commit comments

Comments
 (0)