Skip to content

Commit a88a8bc

Browse files
committed
Add some more power specializations
1 parent 011f8d3 commit a88a8bc

File tree

1 file changed

+66
-47
lines changed

1 file changed

+66
-47
lines changed

src/parse/add_compute.jl

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -452,59 +452,78 @@ end
452452
function add_pow!(
453453
ls::LoopSet, var::Symbol, @nospecialize(x), p::Real, elementbytes::Int, position::Int
454454
)
455-
xop::Operation = if x isa Expr
456-
add_operation!(ls, Symbol("###xpow###$(length(operations(ls)))###"), x, elementbytes, position)
457-
elseif x isa Symbol
458-
if x ls.loopsymbols
459-
add_loopvalue!(ls, x, elementbytes)
455+
xop::Operation = if x isa Expr
456+
add_operation!(ls, Symbol("###xpow###$(length(operations(ls)))###"), x, elementbytes, position)
457+
elseif x isa Symbol
458+
if x ls.loopsymbols
459+
add_loopvalue!(ls, x, elementbytes)
460+
else
461+
xo = get(ls.opdict, x, nothing)
462+
if xo === nothing
463+
if round(Int,p) p
464+
pushpreamble!(ls, Expr(:(=), var, Expr(:call, :(^), x, p)))
465+
return add_constant!(ls, var, elementbytes)
460466
else
461-
xo = get(ls.opdict, x, nothing)
462-
if xo === nothing
463-
pushpreamble!(ls, Expr(:(=), var, Expr(:call, :(^), x, p)))
464-
return add_constant!(ls, var, elementbytes)
465-
end
466-
xo
467+
xo = add_constant!(ls, x, elementbytes)::Operation
467468
end
468-
elseif x isa Number
469-
return add_constant!(ls, x ^ p, elementbytes, var)::Operation
470-
end
471-
pint = round(Int, p)
472-
if p != pint
473-
pop = add_constant!(ls, p, elementbytes)
474-
return add_compute!(ls, var, :^, [xop, pop], elementbytes)
475-
end
476-
if pint == -1
477-
return add_compute!(ls, var, :inv, [xop], elementbytes)
478-
elseif pint < 0
479-
xop = add_compute!(ls, gensym!(ls, "inverse"), :inv, [xop], elementbytes)
480-
pint = - pint
481-
end
482-
if pint == 0
483-
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
484-
push!(ls.preamble_funcofeltypes, (identifier(op),MULTIPLICATIVE_IN_REDUCTIONS))
485-
return pushop!(ls, op)
486-
elseif pint == 1
487-
return add_compute!(ls, var, :identity, [xop], elementbytes)
488-
elseif pint == 2
489-
return add_compute!(ls, var, :abs2_fast, [xop], elementbytes)
469+
end
470+
xo
490471
end
472+
elseif x isa Number
473+
return add_constant!(ls, x ^ p, elementbytes, var)::Operation
474+
end
475+
pint = round(Int, p)
476+
if pint == -1
477+
return add_compute!(ls, var, :inv, [xop], elementbytes)
478+
elseif pint < 0
479+
xop = add_compute!(ls, gensym!(ls, "inverse"), :inv, [xop], elementbytes)
480+
p = -p
481+
pint = - pint
482+
end
483+
if p == 0.5
484+
return add_compute!(ls, var, :sqrt, [xop], elementbytes)
485+
elseif p == 1/3
486+
return add_compute!(ls, var, :cbrt, [xop], elementbytes)
487+
elseif p == 2/3
488+
xop = add_compute!(ls, gensym!(ls, "cbrt"), :cbrt, [xop], elementbytes)
489+
return add_compute!(ls, var, :abs2_fast, [xop], elementbytes)
490+
elseif p == 0.75
491+
xop = add_compute!(ls, gensym!(ls, "root1"), :sqrt, [xop], elementbytes)
492+
xop = add_compute!(ls, gensym!(ls, "root2"), :sqrt, [xop], elementbytes)
493+
pint = 3
494+
elseif p == 0.25
495+
xop = add_compute!(ls, gensym!(ls, "root1"), :sqrt, [xop], elementbytes)
496+
return add_compute!(ls, var, :sqrt, [xop], elementbytes)
497+
elseif p != pint
498+
pop = add_constant!(ls, p, elementbytes)
499+
return add_compute!(ls, var, :^, [xop, pop], elementbytes)
500+
end
501+
if pint == 0
502+
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
503+
push!(ls.preamble_funcofeltypes, (identifier(op),MULTIPLICATIVE_IN_REDUCTIONS))
504+
return pushop!(ls, op)
505+
elseif pint == 1
506+
return add_compute!(ls, var, :identity, [xop], elementbytes)
507+
elseif pint == 2
508+
return add_compute!(ls, var, :abs2_fast, [xop], elementbytes)
509+
end
491510

492-
# Implementation from https://github.com/JuliaLang/julia/blob/a965580ba7fd0e8314001521df254e30d686afbf/base/intfuncs.jl#L216
511+
# Implementation from https://github.com/JuliaLang/julia/blob/a965580ba7fd0e8314001521df254e30d686afbf/base/intfuncs.jl#L216
512+
t = trailing_zeros(pint) + 1
513+
pint >>= t
514+
while (t -= 1) > 0
515+
varname = (iszero(pint) && isone(t)) ? var : gensym!(ls, "pbs")
516+
xop = add_compute!(ls, varname, :abs2_fast, [xop], elementbytes)
517+
end
518+
yop = xop
519+
while pint > 0
493520
t = trailing_zeros(pint) + 1
494521
pint >>= t
495-
while (t -= 1) > 0
496-
varname = (iszero(pint) && isone(t)) ? var : gensym!(ls, "pbs")
497-
xop = add_compute!(ls, varname, :abs2_fast, [xop], elementbytes)
522+
while (t -= 1) >= 0
523+
xop = add_compute!(ls, gensym!(ls, "pbs"), :abs2_fast, [xop], elementbytes)
498524
end
499-
yop = xop
500-
while pint > 0
501-
t = trailing_zeros(pint) + 1
502-
pint >>= t
503-
while (t -= 1) >= 0
504-
xop = add_compute!(ls, gensym!(ls, "pbs"), :abs2_fast, [xop], elementbytes)
505-
end
506-
yop = add_compute!(ls, iszero(pint) ? var : gensym!(ls, "pbs"), :mul_fast, [xop, yop], elementbytes)
507-
end
508-
yop
525+
yop = add_compute!(ls, iszero(pint) ? var : gensym!(ls, "pbs"), :mul_fast, [xop, yop], elementbytes)
526+
end
527+
yop
509528
end
510529

0 commit comments

Comments
 (0)