-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Description
"""
    flip(f, xs)

Apply `f` to the elements of `xs` starting from the last index and moving to the
first, and return the results in the original order, so that `out[t] == f(xs[t])`
(`out` is a 1-based `Vector`).

Visiting the elements back to front matters when `f` keeps state between calls.
The `@code_warntype` session below uses a `Recur{LSTMCell}` for `f`.

For 1-based `xs` this computes the same as

    flipped_xs = Buffer(xs)
    @inbounds for t ∈ reverse(eachindex(xs))
        flipped_xs[t] = f(xs[t])
    end
    return copy(flipped_xs)

It is written with broadcasting instead because Zygote differentiates loops much
more slowly than broadcasting.

Note: when the return type of `f` cannot be inferred, neither can the return type
of this method (`@code_warntype` reports `Body::Any`).
"""
function flip(f, xs)
    rev_time = reverse(eachindex(xs))
    # ys[k] == f(xs[rev_time[k]]): f is applied from the last time step to the first.
    ys = f.(getindex.(Ref(xs), rev_time))
    # `ys` comes from broadcasting over a range, so it is 1-based whatever the axes
    # of `xs` are. Reverse it with its own indices, not with `rev_time`.
    return getindex.(Ref(ys), reverse(eachindex(ys)))
end
# `@code_warntype flip(m, xs)` shows that this version returns Any:
julia> @code_warntype flip(m, xs)
Variables
#self#::Core.Compiler.Const(flip, false)
f::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
xs::Array{Array{Float32,2},1}
rev_time::StepRange{Int64,Int64}
Body::Any
1 ─ %1 = Main.eachindex(xs)::Base.OneTo{Int64}
│ (rev_time = Main.reverse(%1))
│ %3 = Main.Ref(xs)::Base.RefValue{Array{Array{Float32,2},1}}
│ %4 = Base.broadcasted(Main.getindex, %3, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}
│ %5 = Base.broadcasted(f, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}}}
│ %6 = Base.materialize(%5)::Any
│ %7 = Main.Ref(%6)::Base.RefValue{_A} where _A
│ %8 = Base.broadcasted(Main.getindex, %7, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),_A} where _A<:Tuple
│ %9 = Base.materialize(%8)::Any
└── return %9
This is alleviated by asserting the type of the inner broadcast result:
"""
    flip(f, xs)

Apply `f` to the elements of `xs` starting from the last index and moving to the
first, and return the results in the original order, so that `out[t] == f(xs[t])`
(`out` is a 1-based `Vector`).

For 1-based `xs` this is equivalent to filling a `Buffer` in a loop over
`reverse(eachindex(xs))`. It uses broadcasting because Zygote differentiates loops
much more slowly than broadcasting.

The intermediate broadcast result is type-asserted, so the return type is inferred
even when the return type of `f` is not (e.g. a `Recur{LSTMCell}`). This assumes
the broadcast produces elements of exactly `eltype(xs)`, which holds when `f`
preserves a concrete element type. Otherwise a `TypeError` is thrown.
"""
function flip(f, xs::T) where T
    # With no elements, the broadcast cannot narrow its eltype from the values and
    # falls back to the inferred one (possibly `Any`), which would fail the assert.
    isempty(xs) && return Vector{eltype(T)}()
    rev_time = reverse(eachindex(xs))
    # A broadcast over the 1-D range `rev_time` always materializes a `Vector`.
    # Asserting `Vector{eltype(T)}` instead of `T` gives the same type for plain
    # `Vector` input, but also accepts views, ranges and other `AbstractVector`s.
    ys = f.(getindex.(Ref(xs), rev_time))::Vector{eltype(T)}
    # `ys` is 1-based whatever the axes of `xs` are, so reverse it with its own indices.
    return getindex.(Ref(ys), reverse(eachindex(ys)))
end
# julia> @code_warntype flip(m, xs)
Variables
#self#::Core.Compiler.Const(flip, false)
f::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
xs::Array{Array{Float32,2},1}
rev_time::StepRange{Int64,Int64}
Body::Array{Array{Float32,2},1}
1 ─ %1 = Main.eachindex(xs)::Base.OneTo{Int64}
│ (rev_time = Main.reverse(%1))
│ %3 = Main.Ref(xs)::Base.RefValue{Array{Array{Float32,2},1}}
│ %4 = Base.broadcasted(Main.getindex, %3, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}
│ %5 = Base.broadcasted(f, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}}}
│ %6 = Base.materialize(%5)::Any
│ %7 = Core.typeassert(%6, $(Expr(:static_parameter, 1)))::Array{Array{Float32,2},1}
│ %8 = Main.Ref(%7)::Base.RefValue{Array{Array{Float32,2},1}}
│ %9 = Base.broadcasted(Main.getindex, %8, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}
│ %10 = Base.materialize(%9)::Array{Array{Float32,2},1}
└── return %10
Metadata
Assignees
Labels
No labels