Skip to content

Commit 6b717f8

Browse files
committed
Added `_unsafe` variants of the vectorization macros. They are "unsafe" only in the sense that the code they generate is not wrapped in `GC.@preserve`. I added them so that the expression generated by a ProbabilityModel can be run line by line in the REPL without the `GC.@preserve` blocks, which made the generated code slightly less aesthetically pleasing. There is no reason to use the unsafe versions in general.
1 parent 9ca5bdd commit 6b717f8

File tree

1 file changed

+88
-163
lines changed

1 file changed

+88
-163
lines changed

src/LoopVectorization.jl

Lines changed: 88 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ const SLEEFPiratesDict = Dict{Symbol,Tuple{Symbol,Symbol}}(
2323
:exp2 => (:SLEEFPirates, :exp2),
2424
:exp10 => (:SLEEFPirates, :exp10),
2525
:expm1 => (:SLEEFPirates, :expm1),
26-
# :sqrt => (:SLEEFPirates, :sqrt), # faster than sqrt_fast
26+
:inv => (:SIMDPirates, :vinv), # faster than sqrt_fast
2727
:sqrt => (:SIMDPirates, :sqrt), # faster than sqrt_fast
2828
:rsqrt => (:SIMDPirates, :rsqrt),
2929
:cbrt => (:SLEEFPirates, :cbrt_fast),
@@ -124,7 +124,7 @@ function replace_syms_i(expr, set, i)
124124
end
125125
end
126126

127-
@noinline function vectorize_body(N, Tsym::Symbol, uf, n, body, vecdict = SLEEFPiratesDict, VType = SVec, mod = :LoopVectorization)
127+
@noinline function vectorize_body(N, Tsym::Symbol, uf, n, body, vecdict = SLEEFPiratesDict, VType = SVec, gcpreserve::Bool = true , mod = :LoopVectorization)
128128
if Tsym == :Float32
129129
vectorize_body(N, Float32, uf, n, body, vecdict, VType, mod)
130130
elseif Tsym == :Float64
@@ -133,7 +133,7 @@ end
133133
throw("Type $Tsym is not supported.")
134134
end
135135
end
136-
@noinline function vectorize_body(N, T::DataType, unroll_factor::Int, n, body, vecdict = SLEEFPiratesDict, VType = SVec, mod = :LoopVectorization)
136+
@noinline function vectorize_body(N, T::DataType, unroll_factor::Int, n, body, vecdict = SLEEFPiratesDict, VType = SVec, gcpreserve::Bool = true, mod = :LoopVectorization)
137137
# unroll_factor == 1 || throw("Only unroll factor of 1 is currently supported. Was set to $unroll_factor.")
138138
T_size = sizeof(T)
139139
if isa(N, Integer)
@@ -322,7 +322,7 @@ end
322322
push!(q.args, nothing)
323323
# display(q)
324324
# We are using pointers, so better add a GC.@preserve.
325-
gcpreserve = true
325+
# gcpreserve = true
326326
# gcpreserve = false
327327
if gcpreserve
328328
return quote
@@ -339,7 +339,8 @@ end
339339

340340
@noinline function add_masks(expr, masksym, reduction_symbols, default_module = :LoopVectorization)
341341
# println("Called add masks!")
342-
postwalk(expr) do x
342+
# postwalk(expr) do x
343+
prewalk(expr) do x
343344
if @capture(x, M_.vstore!(args__))
344345
M === nothing && (M = default_module)
345346
return :($M.vstore!($(args...), $masksym))
@@ -597,168 +598,92 @@ Arguments are
597598
598599
The default type is Float64, and default UnrollFactor is 1 (no unrolling).
599600
"""
600-
macro vectorize(expr)
601-
if @capture(expr, for n_ 1:N_ body__ end)
602-
# q = vectorize_body(N, Float64, n, body, false)
603-
q = vectorize_body(N, Float64, 1, n, body)
604-
# elseif @capture(expr, for n_ ∈ 1:N_ body__ end)
605-
# q = vectorize_body(N, element_type(body)
606-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
607-
q = vectorize_body(:(length($A)), Float64, 1, n, body)
608-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
609-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), Float64, 1, n, body)
610-
else
611-
throw("Could not match loop expression.")
612-
end
613-
esc(q)
614-
end
615-
macro vectorize(type, expr)
616-
if @capture(expr, for n_ 1:N_ body__ end)
617-
# q = vectorize_body(N, type, n, body, true)
618-
q = vectorize_body(N, type, 1, n, body)
619-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
620-
q = vectorize_body(:(length($A)), type, 1, n, body)
621-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
622-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), type, 1, n, body)
623-
else
624-
throw("Could not match loop expression.")
625-
end
626-
esc(q)
627-
end
628-
macro vectorize(unroll_factor::Integer, expr)
629-
if @capture(expr, for n_ 1:N_ body__ end)
630-
# q = vectorize_body(N, type, n, body, true)
631-
q = vectorize_body(N, Float64, unroll_factor, n, body)
632-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
633-
q = vectorize_body(:(length($A)), Float64, unroll_factor, n, body)
634-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
635-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), Float64, unroll_factor, n, body)
636-
else
637-
throw("Could not match loop expression.")
638-
end
639-
esc(q)
640-
end
641-
macro vectorize(type, unroll_factor::Integer, expr)
642-
if @capture(expr, for n_ 1:N_ body__ end)
643-
# q = vectorize_body(N, type, n, body, true)
644-
q = vectorize_body(N, type, unroll_factor, n, body)
645-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
646-
q = vectorize_body(:(length($A)), type, unroll_factor, n, body)
647-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
648-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), type, unroll_factor, n, body)
649-
else
650-
throw("Could not match loop expression.")
651-
end
652-
esc(q)
653-
end
654601

655-
macro vvectorize(expr)
656-
if @capture(expr, for n_ 1:N_ body__ end)
657-
# q = vectorize_body(N, Float64, n, body, false)
658-
q = vectorize_body(N, Float64, 1, n, body, SLEEFPiratesDict, Vec)
659-
# elseif @capture(expr, for n_ ∈ 1:N_ body__ end)
660-
# q = vectorize_body(N, element_type(body)
661-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
662-
q = vectorize_body(:(length($A)), Float64, 1, n, body, SLEEFPiratesDict, Vec)
663-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
664-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), Float64, 1, n, body, SLEEFPiratesDict, Vec)
602+
for vec (false,true)
603+
if vec
604+
V = Vec
605+
macroname = :vvectorize
665606
else
666-
throw("Could not match loop expression.")
607+
V = SVec
608+
macroname = :vectorize
667609
end
668-
esc(q)
669-
end
670-
macro vvectorize(type, expr)
671-
if @capture(expr, for n_ 1:N_ body__ end)
672-
# q = vectorize_body(N, type, n, body, true)
673-
q = vectorize_body(N, type, 1, n, body, SLEEFPiratesDict, Vec)
674-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
675-
q = vectorize_body(:(length($A)), type, 1, n, body, SLEEFPiratesDict, Vec)
676-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
677-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), type, 1, n, body, SLEEFPiratesDict, Vec)
678-
else
679-
throw("Could not match loop expression.")
680-
end
681-
esc(q)
682-
end
683-
macro vvectorize(unroll_factor::Integer, expr)
684-
if @capture(expr, for n_ 1:N_ body__ end)
685-
# q = vectorize_body(N, type, n, body, true)
686-
q = vectorize_body(N, Float64, unroll_factor, n, body, SLEEFPiratesDict, Vec)
687-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
688-
q = vectorize_body(:(length($A)), Float64, unroll_factor, n, body, SLEEFPiratesDict, Vec)
689-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
690-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), Float64, unroll_factor, n, body, SLEEFPiratesDict, Vec)
691-
else
692-
throw("Could not match loop expression.")
693-
end
694-
esc(q)
695-
end
696-
macro vvectorize(type, unroll_factor::Integer, expr)
697-
if @capture(expr, for n_ 1:N_ body__ end)
698-
# q = vectorize_body(N, type, n, body, true)
699-
q = vectorize_body(N, type, unroll_factor, n, body, SLEEFPiratesDict, Vec)
700-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
701-
q = vectorize_body(:(length($A)), type, unroll_factor, n, body, SLEEFPiratesDict, Vec)
702-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
703-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), type, unroll_factor, n, body, SLEEFPiratesDict, Vec)
704-
else
705-
throw("Could not match loop expression.")
706-
end
707-
esc(q)
708-
end
709-
710-
711-
macro vectorize(type, mod::Union{Symbol,Module}, expr)
712-
if @capture(expr, for n_ 1:N_ body__ end)
713-
# q = vectorize_body(N, type, n, body, true)
714-
q = vectorize_body(N, type, 1, n, body, SLEEFPiratesDict, SVec, mod)
715-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
716-
q = vectorize_body(:(length($A)), type, 1, n, body, SLEEFPiratesDict, SVec, mod)
717-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
718-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), type, 1, n, body, SLEEFPiratesDict, SVec, mod)
719-
else
720-
throw("Could not match loop expression.")
721-
end
722-
esc(q)
723-
end
724-
macro vectorize(type, mod::Union{Symbol,Module}, unroll_factor::Integer, expr)
725-
if @capture(expr, for n_ 1:N_ body__ end)
726-
# q = vectorize_body(N, type, n, body, true)
727-
q = vectorize_body(N, type, unroll_factor, n, body, SLEEFPiratesDict, SVec, mod)
728-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
729-
q = vectorize_body(:(length($A)), type, unroll_factor, n, body, SLEEFPiratesDict, SVec, mod)
730-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
731-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), type, unroll_factor, n, body, SLEEFPiratesDict, SVec, mod)
732-
else
733-
throw("Could not match loop expression.")
734-
end
735-
esc(q)
736-
end
737-
macro vvectorize(type, mod::Union{Symbol,Module}, expr)
738-
if @capture(expr, for n_ 1:N_ body__ end)
739-
# q = vectorize_body(N, type, n, body, true)
740-
q = vectorize_body(N, type, 1, n, body, SLEEFPiratesDict, Vec, mod)
741-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
742-
q = vectorize_body(:(length($A)), type, 1, n, body, SLEEFPiratesDict, Vec, mod)
743-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
744-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), type, 1, n, body, SLEEFPiratesDict, Vec, mod)
745-
else
746-
throw("Could not match loop expression.")
747-
end
748-
esc(q)
749-
end
750-
macro vvectorize(type, unroll_factor::Integer, mod::Union{Symbol,Module}, expr)
751-
if @capture(expr, for n_ 1:N_ body__ end)
752-
# q = vectorize_body(N, type, n, body, true)
753-
q = vectorize_body(N, type, unroll_factor, n, body, SLEEFPiratesDict, Vec, mod)
754-
elseif @capture(expr, for n_ eachindex(A_) body__ end)
755-
q = vectorize_body(:(length($A)), type, unroll_factor, n, body, SLEEFPiratesDict, Vec, mod)
756-
elseif @capture(expr, for n_ eachindex(args__) body__ end)
757-
q = vectorize_body(:(min($([:(length($a)) for a args]...))), type, unroll_factor, n, body, SLEEFPiratesDict, Vec, mod)
758-
else
759-
throw("Could not match loop expression.")
610+
for gcpreserve in (true, false)
    # Generate one family of macro methods per `gcpreserve` setting. `macroname`
    # and `V` are set by the enclosing loop. `gcpreserve` is forwarded to
    # `vectorize_body`; the `false` variants get an `_unsafe` suffix and are meant
    # for generated code that should not contain `GC.@preserve` blocks.
    #
    # The name is derived into a fresh local instead of mutating `macroname`, so
    # the result does not depend on the order in which `(true, false)` is visited.
    name = gcpreserve ? macroname : Symbol(macroname, :_unsafe)

    # Every method below recognises the same three loop headers:
    #   for n in 1:N ... end               -> loop length `N`
    #   for n in eachindex(A) ... end      -> loop length `length(A)`
    #   for n in eachindex(A, B, ...) end  -> loop length `min(length(A), ...)`
    # and throws "Could not match loop expression." for anything else.
    # `$name`, `$V` and `$gcpreserve` are spliced in by `@eval`; the `$` inside
    # the nested `:( ... )` quotes belong to those inner quotes.

    # `@name for ... end` -- Float64 elements, unroll factor 1.
    @eval macro $name(expr)
        if @capture(expr, for n_ in 1:N_ body__ end)
            looplen = N
        elseif @capture(expr, for n_ in eachindex(A_) body__ end)
            looplen = :(length($A))
        elseif @capture(expr, for n_ in eachindex(args__) body__ end)
            looplen = :(min($([:(length($a)) for a in args]...)))
        else
            throw("Could not match loop expression.")
        end
        return esc(vectorize_body(looplen, Float64, 1, n, body, SLEEFPiratesDict, $V, $gcpreserve))
    end

    # `@name type for ... end` -- explicit element type, unroll factor 1.
    @eval macro $name(type, expr)
        if @capture(expr, for n_ in 1:N_ body__ end)
            looplen = N
        elseif @capture(expr, for n_ in eachindex(A_) body__ end)
            looplen = :(length($A))
        elseif @capture(expr, for n_ in eachindex(args__) body__ end)
            looplen = :(min($([:(length($a)) for a in args]...)))
        else
            throw("Could not match loop expression.")
        end
        return esc(vectorize_body(looplen, type, 1, n, body, SLEEFPiratesDict, $V, $gcpreserve))
    end

    # `@name unroll_factor for ... end` -- Float64 elements, explicit unroll factor.
    @eval macro $name(unroll_factor::Integer, expr)
        if @capture(expr, for n_ in 1:N_ body__ end)
            looplen = N
        elseif @capture(expr, for n_ in eachindex(A_) body__ end)
            looplen = :(length($A))
        elseif @capture(expr, for n_ in eachindex(args__) body__ end)
            looplen = :(min($([:(length($a)) for a in args]...)))
        else
            throw("Could not match loop expression.")
        end
        return esc(vectorize_body(looplen, Float64, unroll_factor, n, body, SLEEFPiratesDict, $V, $gcpreserve))
    end

    # `@name type unroll_factor for ... end` -- explicit element type and unroll factor.
    @eval macro $name(type, unroll_factor::Integer, expr)
        if @capture(expr, for n_ in 1:N_ body__ end)
            looplen = N
        elseif @capture(expr, for n_ in eachindex(A_) body__ end)
            looplen = :(length($A))
        elseif @capture(expr, for n_ in eachindex(args__) body__ end)
            looplen = :(min($([:(length($a)) for a in args]...)))
        else
            throw("Could not match loop expression.")
        end
        return esc(vectorize_body(looplen, type, unroll_factor, n, body, SLEEFPiratesDict, $V, $gcpreserve))
    end

    # `@name type mod for ... end` -- explicit element type and module, unroll factor 1.
    @eval macro $name(type, mod::Union{Symbol,Module}, expr)
        if @capture(expr, for n_ in 1:N_ body__ end)
            looplen = N
        elseif @capture(expr, for n_ in eachindex(A_) body__ end)
            looplen = :(length($A))
        elseif @capture(expr, for n_ in eachindex(args__) body__ end)
            looplen = :(min($([:(length($a)) for a in args]...)))
        else
            throw("Could not match loop expression.")
        end
        return esc(vectorize_body(looplen, type, 1, n, body, SLEEFPiratesDict, $V, $gcpreserve, mod))
    end

    # `@name type mod unroll_factor for ... end` -- everything explicit.
    # `$gcpreserve` must precede `mod` for every loop form: `vectorize_body` takes
    # `gcpreserve::Bool` as its 8th positional argument and `mod` as its 9th, so
    # omitting it would pass `mod` where a Bool is required.
    @eval macro $name(type, mod::Union{Symbol,Module}, unroll_factor::Integer, expr)
        if @capture(expr, for n_ in 1:N_ body__ end)
            looplen = N
        elseif @capture(expr, for n_ in eachindex(A_) body__ end)
            looplen = :(length($A))
        elseif @capture(expr, for n_ in eachindex(args__) body__ end)
            looplen = :(min($([:(length($a)) for a in args]...)))
        else
            throw("Could not match loop expression.")
        end
        return esc(vectorize_body(looplen, type, unroll_factor, n, body, SLEEFPiratesDict, $V, $gcpreserve, mod))
    end
end
761-
esc(q)
762687
end
763688

764689
end # module

0 commit comments

Comments
 (0)