File tree Expand file tree Collapse file tree 3 files changed +36
-4
lines changed Expand file tree Collapse file tree 3 files changed +36
-4
lines changed Original file line number Diff line number Diff line change @@ -264,8 +264,12 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector)
264264    vals =  unflatten (svi. values, x)
265265    #  TODO (mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is
266266    #  required but undesireable.
267-     T =  float_type_with_fallback (eltype (x))
268-     accs =  map (acc ->  convert_eltype (T, acc), getaccs (svi))
267+     #  The below line is finicky for type stability. For instance, assigning the eltype to
268+     #  convert to into an intermediate variable makes this unstable (constant propagation)
269+     #  fails. Take care when editing.
270+     accs =  map (
271+         acc ->  convert_eltype (float_type_with_fallback (eltype (x)), acc), getaccs (svi)
272+     )
269273    return  SimpleVarInfo (vals, accs, svi. transformation)
270274end 
271275
Original file line number Diff line number Diff line change @@ -444,8 +444,13 @@ function unflatten(vi::VarInfo, x::AbstractVector)
444444    #  element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
445445    #  messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just
446446    #  plain ugly and hacky.
447-     T =  float_type_with_fallback (eltype (x))
448-     accs =  map (acc ->  convert_eltype (T, acc), deepcopy (getaccs (vi)))
447+     #  The below line is finicky for type stability. For instance, assigning the eltype to
448+     #  convert to into an intermediate variable makes this unstable (constant propagation)
449+     #  fails. Take care when editing.
450+     accs =  map (
451+         acc ->  convert_eltype (float_type_with_fallback (eltype (x)), acc),
452+         deepcopy (getaccs (vi)),
453+     )
449454    return  VarInfo (md, accs)
450455end 
451456
Original file line number Diff line number Diff line change 732732        end 
733733    end 
734734
735+     @testset  " unflatten type stability"   begin 
736+         @model  function  demo (y)
737+             x ~  Normal ()
738+             y ~  Normal (x, 1 )
739+             return  nothing 
740+         end 
741+ 
742+         model =  demo (0.0 )
743+         varinfos =  DynamicPPL. TestUtils. setup_varinfos (
744+             model, (; x= 1.0 ), (@varname (x),); include_threadsafe= true 
745+         )
746+         @testset  " $(short_varinfo_name (varinfo)) "   for  varinfo in  varinfos
747+             #  Skip the severely inconcrete `SimpleVarInfo` types, since checking for type
748+             #  stability for them doesn't make much sense anyway.
749+             if  varinfo isa  SimpleVarInfo{OrderedDict{Any,Any}} || 
750+                 varinfo isa 
751+                DynamicPPL. ThreadSafeVarInfo{<: SimpleVarInfo{OrderedDict{Any,Any}} }
752+                 continue 
753+             end 
754+             @inferred  DynamicPPL. unflatten (varinfo, varinfo[:])
755+         end 
756+     end 
757+ 
735758    @testset  " subset"   begin 
736759        @model  function  demo_subsetting_varinfo (:: Type{TV} = Vector{Float64}) where  {TV}
737760            s ~  InverseGamma (2 , 3 )
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments