File tree Expand file tree Collapse file tree 3 files changed +9
-2
lines changed Expand file tree Collapse file tree 3 files changed +9
-2
lines changed Original file line number Diff line number Diff line change 5151    _import_structure ["controlnets.controlnet_xs" ] =  ["ControlNetXSAdapter" , "UNetControlNetXSModel" ]
5252    _import_structure ["controlnets.multicontrolnet" ] =  ["MultiControlNetModel" ]
5353    _import_structure ["embeddings" ] =  ["ImageProjection" ]
54+     _import_structure ["layerwise_upcasting_utils" ] =  [
55+         "LayerwiseUpcastingGranularity" ,
56+         "apply_layerwise_upcasting" ,
57+         "apply_layerwise_upcasting_hook" ,
58+     ]
5459    _import_structure ["modeling_utils" ] =  ["ModelMixin" ]
5560    _import_structure ["transformers.auraflow_transformer_2d" ] =  ["AuraFlowTransformer2DModel" ]
5661    _import_structure ["transformers.cogvideox_transformer_3d" ] =  ["CogVideoXTransformer3DModel" ]
Original file line number Diff line number Diff line change @@ -321,6 +321,7 @@ def enable_layerwise_upcasting(
321321        storage_dtype : torch .dtype  =  torch .float8_e4m3fn ,
322322        compute_dtype : Optional [torch .dtype ] =  None ,
323323        granularity : LayerwiseUpcastingGranularity  =  LayerwiseUpcastingGranularity .PYTORCH_LAYER ,
324+         skip_modules_pattern : Optional [List [str ]] =  None ,
324325    ) ->  None :
325326        r""" 
326327        Activates layerwise upcasting for the current model. 
@@ -364,7 +365,8 @@ def enable_layerwise_upcasting(
364365                [`~LayerwiseUpcastingGranularity`] for more information. 
365366        """ 
366367
367-         skip_modules_pattern  =  []
368+         if  skip_modules_pattern  is  None :
369+             skip_modules_pattern  =  []
368370        if  self ._keep_in_fp32_modules  is  not None :
369371            skip_modules_pattern .extend (self ._keep_in_fp32_modules )
370372        if  self ._always_upcast_modules  is  not None :
Original file line number Diff line number Diff line change @@ -836,7 +836,7 @@ def __call__(
836836                if  i  ==  len (timesteps ) -  1  or  ((i  +  1 ) >  num_warmup_steps  and  (i  +  1 ) %  self .scheduler .order  ==  0 ):
837837                    progress_bar .update ()
838838
839-         if  not  output_type  ==  "latents " :
839+         if  not  output_type  ==  "latent " :
840840            video  =  self .decode_latents (latents , video_length , decode_chunk_size = 14 )
841841            video  =  self .video_processor .postprocess_video (video = video , output_type = output_type )
842842        else :
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments