Skip to content

Commit 9894290

Browse files
committed
Better handling of function interpolation, lower bound ArrayInterface to a version testing that size isn't broken on vectors.
1 parent e748186 commit 9894290

File tree

6 files changed

+48
-17
lines changed

6 files changed

+48
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1414
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1515

1616
[compat]
17-
ArrayInterface = "2.14.2"
17+
ArrayInterface = "2.14.9"
1818
DocStringExtensions = "0.8"
1919
IfElse = "0.1"
2020
OffsetArrays = "1.4.1"

src/costs.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -463,11 +463,11 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
463463
typeof(Base.ifelse) => :ifelse,
464464
typeof(ifelse) => :ifelse,
465465
typeof(identity) => :identity,
466-
typeof(conj) => :conj,
467-
typeof(zero) => :zero,
468-
typeof(one) => :one,
469-
typeof(axes) => :axes,
470-
typeof(eltype) => :eltype
466+
typeof(conj) => :conj
467+
# typeof(zero) => :zero,
468+
# typeof(one) => :one,
469+
# typeof(axes) => :axes,
470+
# typeof(eltype) => :eltype
471471
)
472472

473473
# implement whitelist for avx_support that package authors may use to conservatively guard `@avx` application

src/graphs.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -573,13 +573,18 @@ function instruction!(ls::LoopSet, x::Expr)
573573
if instr keys(COST)
574574
instr = gensym(:f)
575575
pushpreamble!(ls, Expr(:(=), instr, x))
576+
Instruction(Symbol(""), instr)
577+
else
578+
Instruction(:LoopVectorization, instr)
576579
end
577-
Instruction(Symbol(""), instr)
578580
end
579581
instruction!(ls::LoopSet, x::Symbol) = instruction(x)
580-
function instruction!(ls::LoopSet, ::F) where {F <: Function}
581-
FUNCTIONSYMBOLS[F]
582-
# get(FUNCTIONSYMBOLS, F,
582+
function instruction!(ls::LoopSet, f::F) where {F <: Function}
583+
get(FUNCTIONSYMBOLS, F) do
584+
instr = gensym(:f)
585+
pushpreamble!(ls, Expr(:(=), instr, f))
586+
Instruction(Symbol(""), instr)
587+
end
583588
end
584589

585590

src/lowering.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,13 +652,13 @@ end
652652

653653
function determine_eltype(ls::LoopSet)
654654
if length(ls.includedactualarrays) == 0
655-
return Expr(:call, :typeof, 0)
655+
return Expr(:call, lv(:typeof), 0)
656656
elseif length(ls.includedactualarrays) == 1
657-
return Expr(:call, :eltype, first(ls.includedactualarrays))
657+
return Expr(:call, lv(:eltype), first(ls.includedactualarrays))
658658
end
659-
promote_q = Expr(:call, :promote_type)
659+
promote_q = Expr(:call, lv(:promote_type))
660660
for array ls.includedactualarrays
661-
push!(promote_q.args, Expr(:call, :eltype, array))
661+
push!(promote_q.args, Expr(:call, lv(:eltype), array))
662662
end
663663
promote_q
664664
end

src/reconstruct_loopset.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function Loop(ls::LoopSet, ex::Expr, sym::Symbol, ::Type{<:AbstractUnitRange})
55
start = gensym(ssym*"_loopstart"); stop = gensym(ssym*"_loopstop"); loopsym = gensym(ssym * "_loop")
66
pushpreamble!(ls, Expr(:(=), loopsym, ex))
77
pushpreamble!(ls, Expr(:(=), start, Expr(:call, :first, loopsym)))
8-
pushpreamble!(ls, Expr(:(=), stop, Expr(:call, :last, loopsym)))
8+
pushpreamble!(ls, Expr(:(=), stop, Expr(:call, lv(:last), loopsym)))
99
loop = Loop(sym, 1, 1024, start, stop, false, false)::Loop
1010
pushpreamble!(ls, loopiteratesatleastonce(loop))
1111
loop
@@ -14,14 +14,14 @@ end
1414

1515
function Loop(ls::LoopSet, ex::Expr, sym::Symbol, ::Type{OptionallyStaticUnitRange{I, Static{U}}}) where {I<:Integer, U}
1616
start = gensym(String(sym)*"_loopstart")
17-
pushpreamble!(ls, Expr(:(=), start, Expr(:call, :first, ex)))
17+
pushpreamble!(ls, Expr(:(=), start, Expr(:call, lv(:first), ex)))
1818
loop = Loop(sym, U - 1024, U, start, Symbol(""), false, true)::Loop
1919
pushpreamble!(ls, loopiteratesatleastonce(loop))
2020
loop
2121
end
2222
function Loop(ls::LoopSet, ex::Expr, sym::Symbol, ::Type{OptionallyStaticUnitRange{Static{L}, I}}) where {I <: Integer, L}
2323
stop = gensym(String(sym)*"_loopstop")
24-
pushpreamble!(ls, Expr(:(=), stop, Expr(:call, :last, ex)))
24+
pushpreamble!(ls, Expr(:(=), stop, Expr(:call, lv(:last), ex)))
2525
loop = Loop(sym, L, L + 1024, Symbol(""), stop, true, false)::Loop
2626
pushpreamble!(ls, loopiteratesatleastonce(loop))
2727
loop

test/miscellaneous.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,31 @@ function maybe_const_issue144_avx!(𝛥mat, 𝛥ℛ, mat, ℛ)
787787
end
788788
𝛥mat
789789
end
790+
function grad!(𝛥x, 𝛥ℛ, x, 𝒶𝓍i=eachindex(x))
791+
for i = 𝒶𝓍i
792+
(i >= first(axes(𝛥x, 1))) & (i <= last(axes(𝛥x, 1))) && (𝛥x[i] = 𝛥x[i] + 𝛥ℛ[i])
793+
end
794+
𝛥x
795+
end
796+
function grad_avx!(𝛥x, 𝛥ℛ, x, 𝒶𝓍i=eachindex(x))
797+
@avx for i = 𝒶𝓍i
798+
(i >= first(axes(𝛥x, 1))) & (i <= last(axes(𝛥x, 1))) && (𝛥x[i] = 𝛥x[i] + 𝛥ℛ[i])
799+
end
800+
𝛥x
801+
end
802+
function grad_avx_base!(𝛥x, 𝛥ℛ, x, 𝒶𝓍i=eachindex(x))
803+
@avx for i = 𝒶𝓍i
804+
(i >= first(axes(𝛥x, 1))) & (i <= Base.last(axes(𝛥x, 1))) && (𝛥x[i] = 𝛥x[i] + 𝛥ℛ[i])
805+
end
806+
𝛥x
807+
end
808+
@eval function grad_avx_eval!(𝛥x, 𝛥ℛ, x, 𝒶𝓍i=eachindex(x))
809+
@avx for i = 𝒶𝓍i
810+
(i >= $first($axes(𝛥x, 1))) & (i <= $last($axes(𝛥x, 1))) && (𝛥x[i] = 𝛥x[i] + 𝛥ℛ[i])
811+
end
812+
𝛥x
813+
end # LoadError: KeyError: key typeof(first) not found
814+
790815

791816
for T (Float32, Float64)
792817
@show T, @__LINE__
@@ -1015,6 +1040,7 @@ end
10151040
rtol = (eps(T))
10161041
)
10171042

1043+
@test grad!(zeros(5), ones(5), ones(3)) grad_avx!(zeros(5), ones(5), ones(3)) grad_avx_base!(zeros(5), ones(5), ones(3)) grad_avx_eval!(zeros(5), ones(5), ones(3))
10181044
end
10191045
for T [Int16, Int32, Int64]
10201046
n = 8sizeof(T) - 1

0 commit comments

Comments
 (0)