Skip to content

Commit a73a797

Browse files
authored
combine broadcast statements (#457)
* combine broadcast statements * fix bc and add rrules
1 parent 8ba69ae commit a73a797

File tree

3 files changed

+35
-51
lines changed

3 files changed

+35
-51
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.145"
4+
version = "0.12.146"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/broadcast.jl

Lines changed: 29 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -525,15 +525,15 @@ function add_broadcast_loops!(ls::LoopSet, loopsyms::Vector{Symbol}, destsym::Sy
525525
end
526526
end
527527

528-
# size of dest determines loops
529-
# function vmaterialize!(
530-
@generated function vmaterialize!(
531-
dest::AbstractArray{T,N},
532-
bc::BC,
533-
::Val{Mod},
534-
::Val{UNROLL},
535-
::Val{dontbc},
536-
) where {T<:NativeTypes,N,BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc}
528+
function vmaterialize_fun(
529+
sizeofT::Int,
530+
N,
531+
@nospecialize(_::Type{BC}),
532+
Mod,
533+
UNROLL,
534+
dontbc,
535+
transpose::Bool,
536+
) where {BC}
537537
# 2 + 1
538538
# we have an N dimensional loop.
539539
# need to construct the LoopSet
@@ -542,17 +542,20 @@ end
542542
set_hw!(ls, rs, rc, cls)
543543
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
544544
loopsyms = [gensym!(ls, "n") for _ ∈ 1:N]
545-
add_broadcast_loops!(ls, loopsyms, :dest)
546-
elementbytes = sizeof(T)
545+
transpose && pushprepreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
546+
ret = transpose ? :dest′ : :dest
547+
add_broadcast_loops!(ls, loopsyms, ret)
548+
elementbytes = sizeofT
547549
add_broadcast!(ls, :destination, :bc, loopsyms, BC, dontbc, elementbytes)
550+
transpose && reverse!(loopsyms)
548551
storeop =
549552
add_simple_store!(ls, :destination, ArrayReference(:dest, loopsyms), elementbytes)
550553
doaddref!(ls, storeop)
551554
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
552555
# return ls
553556
sc = setup_call(
554557
ls,
555-
:(Base.Broadcast.materialize!(dest, bc)),
558+
:(Base.Broadcast.materialize!($ret, bc)),
556559
LineNumberNode(0),
557560
inline,
558561
false,
@@ -563,7 +566,19 @@ end
563566
warncheckarg,
564567
safe,
565568
)
566-
Expr(:block, Expr(:meta, :inline), sc, :dest)
569+
Expr(:block, Expr(:meta, :inline), sc, ret)
570+
end
571+
572+
# size of dest determines loops
573+
# function vmaterialize!(
574+
@generated function vmaterialize!(
    dest::AbstractArray{T,N},
    bc::BC,
    ::Val{Mod},
    ::Val{UNROLL},
    ::Val{dontbc},
) where {T<:NativeTypes,N,BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc}
    # Generator body: all code construction is delegated to `vmaterialize_fun`.
    # Only the element size of `T` is forwarded, not `T` itself; the trailing
    # `false` is the `transpose` flag, i.e. `dest` is stored to directly rather
    # than through a `parent(...)` of a wrapped destination.
    elementbytes = sizeof(T)
    return vmaterialize_fun(elementbytes, N, BC, Mod, UNROLL, dontbc, false)
end
568583
@generated function vmaterialize!(
569584
dest′::Union{Adjoint{T,A},Transpose{T,A}},
@@ -580,43 +595,7 @@ end
580595
UNROLL,
581596
dontbc,
582597
}
583-
# we have an N dimensional loop.
584-
# need to construct the LoopSet
585-
ls = LoopSet(Mod)
586-
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL
587-
set_hw!(ls, rs, rc, cls)
588-
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
589-
loopsyms = [gensym!(ls, "n") for _ ∈ 1:N]
590-
pushprepreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
591-
add_broadcast_loops!(ls, loopsyms, :dest′)
592-
elementbytes = sizeof(T)
593-
add_broadcast!(ls, :destination, :bc, loopsyms, BC, dontbc, elementbytes)
594-
storeop = add_simple_store!(
595-
ls,
596-
:destination,
597-
ArrayReference(:dest, reverse(loopsyms)),
598-
elementbytes,
599-
)
600-
doaddref!(ls, storeop)
601-
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
602-
Expr(
603-
:block,
604-
Expr(:meta, :inline),
605-
setup_call(
606-
ls,
607-
:(Base.Broadcast.materialize!(dest′, bc)),
608-
LineNumberNode(0),
609-
inline,
610-
false,
611-
u₁,
612-
u₂,
613-
v,
614-
threads % Int,
615-
warncheckarg,
616-
safe,
617-
),
618-
:dest′,
619-
)
598+
vmaterialize_fun(sizeof(T), N, BC, Mod, UNROLL, dontbc, true)
620599
end
621600
# these are marked `@inline` so the `@turbo` itself can choose whether or not to inline.
622601
@generated function vmaterialize!(

src/simdfunctionals/vmap_grad_rrule.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,8 @@ function ChainRulesCore.rrule(::typeof(vmap), f::F, args::Vararg{Any,K}) where {
100100
∂vmap_singlethread!(f, jacs, out, args...)
101101
out, SIMDMapBack(jacs)
102102
end
103+
# Define `ChainRulesCore.rrule` for the `vmap` variants `vmapt`, `vmapnt` and
# `vmapntt` by forwarding to the rule already defined for `vmap`.
for variant in (:vmapt, :vmapnt, :vmapntt)
    @eval function ChainRulesCore.rrule(
        ::typeof($variant),
        f::F,
        args::Vararg{Any,K},
    ) where {F,K}
        # Forward the function `vmap` itself, not `typeof(vmap)`: the `vmap` rule
        # is declared on `::typeof(vmap)`, which matches only the `vmap` instance.
        # Passing the type would dispatch on `DataType` and miss that rule.
        return ChainRulesCore.rrule(vmap, f, args...)
    end
end

0 commit comments

Comments
 (0)