@@ -333,7 +333,7 @@ def get_control_advanced(self, x_noisy, t, cond, batched_number, transformer_opt
333333 if cond .get ('c_concat' , None ) is not None :
334334 x_noisy = torch .cat ([x_noisy ] + [cond ['c_concat' ]], dim = 1 )
335335
336- control = self .control_model (x = x_noisy .to (dtype ), hint = self .cond_hint , timesteps = timestep .float (), context = context . to ( dtype ), y = y , cond = cond )
336+ control = self .control_model (x = x_noisy .to (dtype ), hint = self .cond_hint , timesteps = timestep .float (), context = comfy . model_management . cast_to_device ( context , x_noisy . device , dtype ), y = y , cond = cond )
337337 return self .control_merge (control , control_prev , output_dtype )
338338
339339 def copy (self ):
@@ -463,7 +463,7 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int, tr
463463 timestep = self .model_sampling_current .timestep (t )
464464 x_noisy = self .model_sampling_current .calculate_input (t , x_noisy )
465465
466- control = self .control_model (x = x_noisy .to (dtype ), hint = self .cond_hint , timesteps = timestep .float (), context = context . to ( dtype ), y = y )
466+ control = self .control_model (x = x_noisy .to (dtype ), hint = self .cond_hint , timesteps = timestep .float (), context = comfy . model_management . cast_to_device ( context , x_noisy . device , dtype ), y = y )
467467 return self .control_merge (control , control_prev , output_dtype )
468468
469469 def apply_advanced_strengths_and_masks (self , x : Tensor , batched_number : int , * args , ** kwargs ):
0 commit comments