Skip to content

@code_warntype outputs for functions participating in the forward pass #7

@AzamatB

Description

@AzamatB
"""
    flip(f, xs)

Apply `f` to the elements of `xs` from the last index to the first, and return the
results in the original order (for 1-based `xs`, `result[t] == f(xs[t])`).

Because the calls happen last-to-first, a stateful `f` (presumably a recurrent layer
in this use case) sees `xs` backwards, while the returned sequence stays aligned
with `xs`.

Note: the result is built purely by broadcasting with no type assertion, so its
inferred type is only as precise as Julia's inference of `f`'s return type; when
that is not inferable, the whole function is inferred as returning `Any`.
"""
function flip(f, xs)
   # indices of `xs`, ordered last to first
   rev_time = reverse(eachindex(xs))
   # inner broadcast: gather `xs` in reverse order (a 1-based Vector);
   # middle broadcast: apply `f` along that reversed sequence, in order;
   # outer broadcast: read those results back in reverse, restoring the order of `xs`.
   # NOTE(review): the outer gather indexes the 1-based intermediate with `rev_time`
   # (the indices of `xs`), so this presumably assumes `xs` is 1-based.
   return getindex.(Ref(f.(getindex.(Ref(xs), rev_time))), rev_time)
   # the same as
   # flipped_xs = Buffer(xs)
   # @inbounds for t ∈ rev_time
   #    flipped_xs[t] = f(xs[t])
   # end
   # return copy(flipped_xs)
   # but implemented via broadcasting as Zygote differentiates loops much slower than broadcasting
end

`@code_warntype` shows that this version is type-unstable: its inferred return type is `Any`.

julia> @code_warntype flip(m, xs)
Variables
  #self#::Core.Compiler.Const(flip, false)
  f::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
  xs::Array{Array{Float32,2},1}
  rev_time::StepRange{Int64,Int64}

Body::Any
1 ─ %1 = Main.eachindex(xs)::Base.OneTo{Int64}
│        (rev_time = Main.reverse(%1))
│   %3 = Main.Ref(xs)::Base.RefValue{Array{Array{Float32,2},1}}
│   %4 = Base.broadcasted(Main.getindex, %3, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}
│   %5 = Base.broadcasted(f, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}}}
│   %6 = Base.materialize(%5)::Any
│   %7 = Main.Ref(%6)::Base.RefValue{_A} where _A
│   %8 = Base.broadcasted(Main.getindex, %7, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),_A} where _A<:Tuple
│   %9 = Base.materialize(%8)::Any
└──      return %9

This is alleviated by adding a type assertion on the intermediate broadcast result:

"""
    flip(f, xs)

Apply `f` to every element of `xs`, visiting the elements from the last index to the
first, and return the results in the original order as a 1-based `Vector`:
`result[k] == f(xs[i])`, where `i` is the `k`-th index of `xs`.

Because the calls happen last-to-first, a stateful `f` (e.g. a recurrent layer) sees
`xs` backwards, while the returned sequence stays aligned with `xs`.

The result is asserted to be a `Vector{eltype(xs)}`. In other words, `f` is assumed to
return values of the same type as the elements of `xs`, and a `TypeError` is thrown
otherwise. This assertion is what makes the return type inferable even when the return
type of `f` is not. An empty `xs` yields an empty `Vector{eltype(xs)}` without calling `f`.
"""
function flip(f, xs::T) where T
   # Equivalent to
   #    flipped_xs = Buffer(xs)
   #    @inbounds for t ∈ rev_time
   #       flipped_xs[t] = f(xs[t])
   #    end
   #    return copy(flipped_xs)
   # but implemented via broadcasting as Zygote differentiates loops much slower than
   # broadcasting.

   # Derived from the type of `xs` alone, so inference treats it as a constant.
   # Asserting `Vector{eltype(T)}` rather than `T` keeps inference concrete while also
   # accepting inputs that are not `Vector`s (views, tuples, ...): the broadcast below
   # always produces a `Vector`.
   OutT = Vector{eltype(T)}
   # Broadcasting over an empty collection takes the *inferred* return type of `f` at
   # face value (possibly `Any`), which would fail the assertion below.
   isempty(xs) && return OutT()
   rev_time = reverse(eachindex(xs))
   # `ys` is 1-based with `ys[k] == f(xs[rev_time[k]])`; `f` is called in that order
   ys = f.(getindex.(Ref(xs), rev_time))::OutT
   # Undo the reversal using `ys`'s own indices. `rev_time` holds the indices of `xs`,
   # which differ from those of `ys` when `xs` is not 1-based.
   return getindex.(Ref(ys), reverse(eachindex(ys)))
end
julia> @code_warntype flip(m, xs)
Variables
  #self#::Core.Compiler.Const(flip, false)
  f::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
  xs::Array{Array{Float32,2},1}
  rev_time::StepRange{Int64,Int64}

Body::Array{Array{Float32,2},1}
1 ─ %1  = Main.eachindex(xs)::Base.OneTo{Int64}
│         (rev_time = Main.reverse(%1))
│   %3  = Main.Ref(xs)::Base.RefValue{Array{Array{Float32,2},1}}
│   %4  = Base.broadcasted(Main.getindex, %3, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}
│   %5  = Base.broadcasted(f, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}}}
│   %6  = Base.materialize(%5)::Any
│   %7  = Core.typeassert(%6, $(Expr(:static_parameter, 1)))::Array{Array{Float32,2},1}
│   %8  = Main.Ref(%7)::Base.RefValue{Array{Array{Float32,2},1}}
│   %9  = Base.broadcasted(Main.getindex, %8, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}
│   %10 = Base.materialize(%9)::Array{Array{Float32,2},1}
└──       return %10

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions