Skip to content

Commit e3faaa9

Browse files
committed
use rationalize for literal powers, test
1 parent b3bb5b4 commit e3faaa9

File tree

2 files changed

+52
-24
lines changed

2 files changed

+52
-24
lines changed

src/parse/add_compute.jl

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -472,42 +472,57 @@ function add_pow!(
472472
elseif x isa Number
473473
return add_constant!(ls, x ^ p, elementbytes, var)::Operation
474474
end
475-
pint = round(Int, p)
475+
local pnum::Int, pden::Int
476+
if p isa Integer
477+
pnum = Int(p)::Int
478+
pden = 1
479+
else
480+
prational = rationalize(p)
481+
@unpack num, den = prational
482+
pnum = convert(Int,num)::Int
483+
pden = convert(Int,den)::Int
484+
end
485+
if pden == 1
486+
nothing
487+
elseif pden == 2
488+
if pnum == 1
489+
return add_compute!(ls, var, :sqrt, [xop], elementbytes)
490+
else
491+
xop = add_compute!(ls, gensym!(ls,"root"), :sqrt, [xop], elementbytes)
492+
end
493+
elseif pden == 3
494+
if pnum == 1
495+
return add_compute!(ls, var, :cbrt, [xop], elementbytes)
496+
else
497+
xop = add_compute!(ls, gensym!(ls,"cbroot"), :cbrt, [xop], elementbytes)
498+
end
499+
elseif pden == 4
500+
xop = add_compute!(ls, gensym!(ls,"root"), :sqrt, [xop], elementbytes)
501+
if pnum == 1
502+
return add_compute!(ls, var, :sqrt, [xop], elementbytes)
503+
else
504+
xop = add_compute!(ls, gensym!(ls,"root"), :sqrt, [xop], elementbytes)
505+
end
506+
else
507+
pop = add_constant!(ls, p, elementbytes)
508+
return add_compute!(ls, var, :^, [xop, pop], elementbytes)
509+
end
510+
pint = pnum
476511
if pint == -1
477512
return add_compute!(ls, var, :inv, [xop], elementbytes)
478513
elseif pint < 0
479514
xop = add_compute!(ls, gensym!(ls, "inverse"), :inv, [xop], elementbytes)
480-
p = -p
481-
pint = - pint
482-
end
483-
if p == 0.5
484-
return add_compute!(ls, var, :sqrt, [xop], elementbytes)
485-
elseif p == 1/3
486-
return add_compute!(ls, var, :cbrt, [xop], elementbytes)
487-
elseif p == 2/3
488-
xop = add_compute!(ls, gensym!(ls, "cbrt"), :cbrt, [xop], elementbytes)
489-
return add_compute!(ls, var, :abs2_fast, [xop], elementbytes)
490-
elseif p == 0.75
491-
xop = add_compute!(ls, gensym!(ls, "root1"), :sqrt, [xop], elementbytes)
492-
xop = add_compute!(ls, gensym!(ls, "root2"), :sqrt, [xop], elementbytes)
493-
pint = 3
494-
elseif p == 0.25
495-
xop = add_compute!(ls, gensym!(ls, "root1"), :sqrt, [xop], elementbytes)
496-
return add_compute!(ls, var, :sqrt, [xop], elementbytes)
497-
elseif p != pint
498-
pop = add_constant!(ls, p, elementbytes)
499-
return add_compute!(ls, var, :^, [xop, pop], elementbytes)
515+
pint = -pint
500516
end
501517
if pint == 0
502518
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
503519
push!(ls.preamble_funcofeltypes, (identifier(op),MULTIPLICATIVE_IN_REDUCTIONS))
504520
return pushop!(ls, op)
505-
elseif pint == 1
521+
elseif pint == 1#requires `pden ≠ 1`.
506522
return add_compute!(ls, var, :identity, [xop], elementbytes)
507523
elseif pint == 2
508524
return add_compute!(ls, var, :abs2_fast, [xop], elementbytes)
509525
end
510-
511526
# Implementation from https://github.com/JuliaLang/julia/blob/a965580ba7fd0e8314001521df254e30d686afbf/base/intfuncs.jl#L216
512527
t = trailing_zeros(pint) + 1
513528
pint >>= t

test/special.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,14 @@
263263
y[i] = x[i] ^ p[i]
264264
end; y
265265
end
266+
@generated function vpow!(y, x, ::Val{p}) where {p}
267+
quote
268+
@turbo for i ∈ eachindex(y,x)
269+
y[i] = x[i] ^ $p
270+
end
271+
return y
272+
end
273+
end
266274

267275
function csetanh!(y, z, x)
268276
for j in axes(x, 2)
@@ -427,7 +435,12 @@
427435
@test vpowf!(r1, x) ≈ (r2 .= x .^ 2.3)
428436
@test vpowf!(r1, x, -1.7) ≈ (r2 .= x .^ -1.7)
429437
p = randn(length(x));
430-
@test vpowf!(r1, x, x) ≈ (r2 .= x .^ x)
438+
@test vpowf!(r1, x, x) ≈ (r2 .= x .^ x)
439+
@test vpow!(r1, x, Val(0.75)) ≈ (r2 .= x .^ 0.75)
440+
@test vpow!(r1, x, Val(2/3)) ≈ (r2 .= x .^ (2/3))
441+
@test vpow!(r1, x, Val(0.5)) == (r2 .= sqrt.(x))
442+
@test vpow!(r1, x, Val(1/4)) ≈ (r2 .= x .^ (1/4))
443+
@test vpow!(r1, x, Val(4.5)) ≈ (r2 .= x .^ 4.5)
431444

432445
X = rand(T, N, M); Z = rand(T, N, M);
433446
Y1 = similar(X); Y2 = similar(Y1);

0 commit comments

Comments
 (0)