File tree Expand file tree Collapse file tree 2 files changed +3
-5
lines changed 
src/diffusers/models/unets Expand file tree Collapse file tree 2 files changed +3
-5
lines changed Original file line number Diff line number Diff line change @@ -295,8 +295,7 @@ def forward(
295295        # timesteps does not contain any weights and will always return f32 tensors 
296296        # but time_embedding might actually be running in fp16. so we need to cast here. 
297297        # there might be better ways to encapsulate this. 
298-         # TODO(aryan): Need to have this reviewed 
299-         t_emb  =  t_emb .to (dtype = sample .dtype )
298+         t_emb  =  t_emb .to (dtype = self .dtype )
300299        emb  =  self .time_embedding (t_emb )
301300
302301        if  self .class_embedding  is  not None :
@@ -306,7 +305,7 @@ def forward(
306305            if  self .config .class_embed_type  ==  "timestep" :
307306                class_labels  =  self .time_proj (class_labels )
308307
309-             class_emb  =  self .class_embedding (class_labels ).to (dtype = sample .dtype )
308+             class_emb  =  self .class_embedding (class_labels ).to (dtype = self .dtype )
310309            emb  =  emb  +  class_emb 
311310        elif  self .class_embedding  is  None  and  class_labels  is  not None :
312311            raise  ValueError ("class_embedding needs to be initialized in order to use class conditioning" )
Original file line number Diff line number Diff line change @@ -2133,8 +2133,7 @@ def forward(
21332133        # timesteps does not contain any weights and will always return f32 tensors 
21342134        # but time_embedding might actually be running in fp16. so we need to cast here. 
21352135        # there might be better ways to encapsulate this. 
2136-         # TODO(aryan): Need to have this reviewed 
2137-         t_emb  =  t_emb .to (dtype = sample .dtype )
2136+         t_emb  =  t_emb .to (dtype = self .dtype )
21382137
21392138        emb  =  self .time_embedding (t_emb , timestep_cond )
21402139        aug_emb  =  None 
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments