Skip to content

Commit 8e18c5b

Browse files
committed
Faster math, primarily in broadcasts.
1 parent 4aa57a2 commit 8e18c5b

File tree

3 files changed

+55
-86
lines changed

3 files changed

+55
-86
lines changed

src/LoopVectorization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ using Base.Meta: isexpr
3333
using DocStringExtensions
3434
import LinearAlgebra # for check_args
3535

36-
using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast, inv_fast, abs2_fast, rem_fast, max_fast, min_fast, pow_fast
37-
using SLEEFPirates: log_fast, log2_fast, log10_fast, pow
36+
using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast, inv_fast, abs2_fast, rem_fast, max_fast, min_fast, pow_fast, sqrt_fast
37+
using SLEEFPirates: log_fast, log2_fast, log10_fast, pow, sin_fast, cos_fast, sincos_fast
3838

3939
using ArrayInterface
4040
using ArrayInterface: OptionallyStaticUnitRange, OptionallyStaticRange, Zero, One, StaticBool, True, False, reduce_tup, indices, UpTri, LoTri

src/modeling/costs.jl

Lines changed: 45 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -207,23 +207,6 @@ const COST = Dict{Symbol,InstructionCost}(
207207
:vfmsub231 => InstructionCost(4,0.5), # - and * will fuse into this, so much of the time they're not twice as expensive
208208
:vfnmadd231 => InstructionCost(4,0.5), # + and -* will fuse into this, so much of the time they're not twice as expensive
209209
:vfnmsub231 => InstructionCost(4,0.5), # - and -* will fuse into this, so much of the time they're not twice as expensive
210-
# :vfmadd! => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
211-
# :vfnmadd! => InstructionCost(4,0.5), # + and -* will fuse into this, so much of the time they're not twice as expensive
212-
# :vfmsub! => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
213-
# :vfnmsub! => InstructionCost(4,0.5), # + and -* will fuse into this, so much of the time they're not twice as expensive
214-
# :vfmadd_fast => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
215-
# :vfmsub_fast => InstructionCost(4,0.5), # - and * will fuse into this, so much of the time they're not twice as expensive
216-
# :vfnmadd_fast => InstructionCost(4,0.5), # + and -* will fuse into this, so much of the time they're not twice as expensive
217-
# :vfnmsub_fast => InstructionCost(4,0.5), # - and -* will fuse into this, so much of the time they're not twice as expensive
218-
# :vfmaddaddone => InstructionCost(4,0.5), # - and -* will fuse into this, so much of the time they're not twice as expensive
219-
# :vmullog2 => InstructionCost(4,0.5),
220-
# :vmullog2add! => InstructionCost(4,0.5),
221-
# :vmullog10 => InstructionCost(4,0.5),
222-
# :vmullog10add! => InstructionCost(4,0.5),
223-
# :vdivlog2 => InstructionCost(13,4.0,-2.0),
224-
# :vdivlog2add! =>InstructionCost(13,4.0,-2.0),
225-
# :vdivlog10 => InstructionCost(13,4.0,-2.0),
226-
# :vdivlog10add! =>InstructionCost(13,4.0,-2.0),
227210
:sqrt => InstructionCost(15,4.0,-2.0),
228211
:sqrt_fast => InstructionCost(15,4.0,-2.0),
229212
:log => InstructionCost(-3.0, 15, 30, 11),
@@ -242,14 +225,14 @@ const COST = Dict{Symbol,InstructionCost}(
242225
:sin => InstructionCost(-3,30.0,60.0,23),
243226
:cos => InstructionCost(-3,27.0,60.0,26),
244227
:sincos => InstructionCost(-3,37.0,85.0,26),
228+
:sincos_fast => InstructionCost(-3,37.0,85.0,26),
245229
:sinpi => InstructionCost(18,15.0,68.0,23),
246230
:cospi => InstructionCost(18,15.0,68.0,26),
247231
:sincospi => InstructionCost(25,37.0,70.0,26),
248232
:log_fast => InstructionCost(20,20.0,40.0,20),
249233
:exp_fast => InstructionCost(20,20.0,20.0,18),
250234
:sin_fast => InstructionCost(18,15.0,68.0,23),
251235
:cos_fast => InstructionCost(18,15.0,68.0,26),
252-
:sincos => InstructionCost(-3,37.0,85.0,26),
253236
:sinpi_fast => InstructionCost(18,15.0,68.0,23),
254237
:cospi_fast => InstructionCost(18,15.0,68.0,26),
255238
:sincospi_fast => InstructionCost(25,22.0,70.0,26),
@@ -502,35 +485,29 @@ end
502485
isreductcombineinstr(instr::Instruction) = isreductcombineinstr(instr.instr)
503486

504487
const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
505-
typeof(+) => :(+),
506-
typeof(VectorizationBase.vadd) => :(+),
507-
# typeof(VectorizationBase.vadd!) => :(+),
508-
typeof(Base.FastMath.add_fast) => :add_fast,
509-
typeof(-) => :(-),
510-
typeof(VectorizationBase.vsub) => :(-),
511-
# typeof(VectorizationBase.vsub!) => :(-),
512-
typeof(Base.FastMath.sub_fast) => :sub_fast,
513-
typeof(*) => :(*),
514-
typeof(VectorizationBase.vmul) => :(*),
515-
# typeof(VectorizationBase.vmul!) => :(*),
516-
typeof(Base.FastMath.mul_fast) => :mul_fast,
517-
typeof(/) => :(/),
518-
typeof(^) => :(^),
519-
# typeof(VectorizationBase.vfdiv) => :(/),
520-
# typeof(VectorizationBase.vfdiv!) => :(/),
521-
typeof(VectorizationBase.vdiv) => :(/),
488+
typeof(+) => :add_fast,
489+
typeof(VectorizationBase.vadd) => :add_fast,
490+
typeof(add_fast) => :add_fast,
491+
typeof(-) => :sub_fast,
492+
typeof(VectorizationBase.vsub) => :sub_fast,
493+
typeof(sub_fast) => :sub_fast,
494+
typeof(*) => :mul_fast,
495+
typeof(VectorizationBase.vmul) => :mul_fast,
496+
typeof(mul_fast) => :mul_fast,
497+
typeof(/) => :div_fast,
498+
typeof(^) => :pow_fast,
499+
typeof(VectorizationBase.vdiv) => :div_fast,
522500
typeof(÷) => :(÷),
523-
typeof(Base.FastMath.div_fast) => :div_fast,
524-
typeof(Base.FastMath.div_fast) => :div_fast,
525-
typeof(Base.FastMath.rem_fast) => :rem_fast,
501+
typeof(div_fast) => :div_fast,
502+
typeof(rem_fast) => :rem_fast,
526503
typeof(==) => :(==),
527504
typeof(!=) => :(!=),
528505
typeof(isequal) => :isequal,
529506
typeof(isnan) => :isnan,
530507
typeof(isinf) => :isinf,
531508
typeof(isfinite) => :isfinite,
532509
typeof(abs) => :abs,
533-
typeof(abs2) => :abs2,
510+
typeof(abs2) => :abs2_fast,
534511
typeof(abs2_fast) => :abs2_fast,
535512
typeof(~) => :(~),
536513
typeof(!) => :(!),
@@ -541,15 +518,15 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
541518
typeof(<) => :(<),
542519
typeof(>=) => :(>=),
543520
typeof(<=) => :(<=),
544-
typeof(inv) => :inv,
521+
typeof(inv) => :inv_fast,
545522
typeof(inv_fast) => :inv_fast,
546-
typeof(muladd) => :muladd,
547-
typeof(fma) => :fma,
548-
typeof(VectorizationBase.vfma) => :vfma,
549-
typeof(VectorizationBase.vmuladd) => :vmuladd,
550-
typeof(VectorizationBase.vfmsub) => :vfmsub,
551-
typeof(VectorizationBase.vfnmadd) => :vfnmadd,
552-
typeof(VectorizationBase.vfnmsub) => :vfnmsub,
523+
typeof(muladd) => :vmuladd_fast,
524+
typeof(fma) => :vfma_fast,
525+
typeof(VectorizationBase.vfma) => :vfma_fast,
526+
typeof(VectorizationBase.vmuladd) => :vmuladd_fast,
527+
typeof(VectorizationBase.vfmsub) => :vfmsub_fast,
528+
typeof(VectorizationBase.vfnmadd) => :vfnmadd_fast,
529+
typeof(VectorizationBase.vfnmsub) => :vfnmsub_fast,
553530
typeof(VectorizationBase.vfma_fast) => :vfma_fast,
554531
typeof(VectorizationBase.vmuladd_fast) => :vmuladd_fast,
555532
typeof(VectorizationBase.vfmsub_fast) => :vfmsub_fast,
@@ -559,30 +536,12 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
559536
typeof(VectorizationBase.vfmsub231) => :vfmsub231,
560537
typeof(VectorizationBase.vfnmadd231) => :vfnmadd231,
561538
typeof(VectorizationBase.vfnmsub231) => :vfnmsub231,
562-
# typeof(VectorizationBase.vfmadd!) => :vfmadd!,
563-
# typeof(VectorizationBase.vfnmadd!) => :vfnmadd!,
564-
# typeof(VectorizationBase.vfmsub!) => :vfmsub!,
565-
# typeof(VectorizationBase.vfnmsub!) => :vfnmsub!,
566-
# typeof(VectorizationBase.vfmadd_fast) => :vfmadd_fast,
567-
# typeof(VectorizationBase.vfmsub_fast) => :vfmsub_fast,
568-
# typeof(VectorizationBase.vfnmadd_fast) => :vfnmadd_fast,
569-
# typeof(VectorizationBase.vfnmsub_fast) => :vfnmsub_fast,
570-
# typeof(vfmaddaddone) => :vfmaddaddone,
571-
# typeof(vmullog2) => :vmullog2,
572-
# typeof(vmullog2add!) => :vmullog2add!,
573-
# typeof(vmullog10) => :vmullog10,
574-
# typeof(vmullog10add!) => :vmullog10add!,
575-
# typeof(vdivlog2) => :vdivlog2,
576-
# typeof(vdivlog2add!) => :vdivlog2add!,
577-
# typeof(vdivlog10) => :vdivlog10,
578-
# typeof(vdivlog10add!) => :vdivlog10add!,
579-
typeof(sqrt) => :sqrt,
580-
typeof(Base.FastMath.sqrt_fast) => :sqrt,
581-
# typeof(VectorizationBase.vsqrt) => :sqrt,
582-
typeof(log) => :log,
583-
typeof(log2) => :log2,
584-
typeof(log10) => :log10,
585-
typeof(Base.FastMath.log_fast) => :log,
539+
typeof(sqrt) => :sqrt_fast,
540+
typeof(sqrt_fast) => :sqrt_fast,
541+
typeof(log) => :log_fast,
542+
typeof(log2) => :log2_fast,
543+
typeof(log10) => :log10_fast,
544+
typeof(log_fast) => :log_fast,
586545
typeof(log1p) => :log1p,
587546
# typeof(VectorizationBase.vlog) => :log,
588547
typeof(SLEEFPirates.log) => :log,
@@ -593,20 +552,23 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
593552
typeof(expm1) => :expm1,
594553
# typeof(VectorizationBase.vexp) => :exp,
595554
typeof(SLEEFPirates.exp) => :exp,
596-
typeof(sin) => :sin,
597-
typeof(Base.FastMath.sin_fast) => :sin,
598-
typeof(SLEEFPirates.sin) => :sin,
599-
typeof(cos) => :cos,
600-
typeof(Base.FastMath.cos_fast) => :cos,
601-
typeof(SLEEFPirates.cos) => :cos,
602-
typeof(sincos) => :sincos,
603-
typeof(Base.FastMath.sincos_fast) => :sincos,
604-
typeof(SLEEFPirates.sincos) => :sincos,
555+
typeof(sin) => :sin_fast,
556+
typeof(Base.FastMath.sin_fast) => :sin_fast,
557+
typeof(SLEEFPirates.sin) => :sin_fast,
558+
typeof(cos) => :cos_fast,
559+
typeof(Base.FastMath.cos_fast) => :cos_fast,
560+
typeof(SLEEFPirates.cos) => :cos_fast,
561+
typeof(sincos) => :sincos_fast,
562+
typeof(Base.FastMath.sincos_fast) => :sincos_fast,
563+
typeof(SLEEFPirates.sincos) => :sincos_fast,
564+
typeof(tan) => :tan_fast,
565+
typeof(Base.FastMath.sincos_fast) => :sincos_fast,
566+
typeof(SLEEFPirates.sincos) => :sincos_fast,
605567
typeof(Base.tanh) => :tanh,
606568
typeof(tanh_fast) => :tanh_fast,
607569
typeof(sigmoid_fast) => :sigmoid_fast,
608-
typeof(max) => :max,
609-
typeof(min) => :min,
570+
typeof(max) => :max_fast,
571+
typeof(min) => :min_fast,
610572
typeof(max_fast) => :max_fast,
611573
typeof(min_fast) => :min_fast,
612574
typeof(relu) => :relu,
@@ -618,7 +580,7 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
618580
typeof(Base.ifelse) => :ifelse,
619581
typeof(ifelse) => :ifelse,
620582
typeof(identity) => :identity,
621-
typeof(conj) => :conj
583+
typeof(conj) => :identity#conj
622584
# typeof(zero) => :zero,
623585
# typeof(one) => :one,
624586
# typeof(axes) => :axes,

src/vectorizationbase_compat/contract_pass.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,14 @@ function append_args_skip!(call, args, i, mod)
5050
call
5151
end
5252

53-
fastfunc(f) = get(VectorizationBase.FASTDICT, f, f)
53+
"""
    fastfunc(f)

Return the fast-math replacement for the function symbol `f`.

`:sin`, `:cos` and `:sincos` are mapped directly to `:sin_fast`, `:cos_fast`
and `:sincos_fast`. Any other `f` is looked up in `VectorizationBase.FASTDICT`;
when it has no entry there, `f` itself is returned unchanged.
"""
function fastfunc(f)
    # The trigonometric symbols are handled here, ahead of the FASTDICT lookup.
    f === :sin && return :sin_fast
    f === :cos && return :cos_fast
    f === :sincos && return :sincos_fast
    # Everything else defers to the shared table, defaulting to `f` itself.
    return get(VectorizationBase.FASTDICT, f, f)
end
5461
function muladd_arguments!(argv, mod, f = first(argv))
5562
if f === :*
5663
argv[1] = :mul_fast

0 commit comments

Comments
 (0)