@@ -527,72 +527,13 @@ def _convert_to_beta(
527527 )
528528 return sigmas
529529
def convert_noise_to_x0(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int) -> torch.Tensor:
    """
    Recover the predicted original sample (x0) from the raw model output.

    How `model_output` is interpreted depends on `self.config.prediction_type`:

    - `"epsilon"`: `(sample - sigma_t * model_output) / alpha_t`
    - `"sample"`: `model_output` is returned unchanged
    - `"v_prediction"`: `alpha_t * sample - sigma_t * model_output`

    Args:
        model_output (`torch.Tensor`): The model output.
        sample (`torch.Tensor`): A current instance of a sample created by the diffusion process.
        timestep (`int`):
            The current discrete timestep in the diffusion chain. Only used to initialise the step index when
            `self.step_index` is still `None`.

    Returns:
        `torch.Tensor`: The predicted original sample (x0).

    Raises:
        ValueError: If `self.config.prediction_type` is not one of the three supported values.
    """
    # The step index may not have been set yet; derive it from the timestep on first use.
    if self.step_index is None:
        self._init_step_index(timestep)

    # Split the sigma of the current step into its (alpha_t, sigma_t) pair.
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])

    prediction_type = self.config.prediction_type
    if prediction_type == "sample":
        # The model already predicts x0 directly.
        return model_output
    if prediction_type == "epsilon":
        return (sample - sigma_t * model_output) / alpha_t
    if prediction_type == "v_prediction":
        return alpha_t * sample - sigma_t * model_output

    raise ValueError(
        f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
        " `v_prediction`."
    )
558-
def convert_x0_to_noise(self, pred_x0: torch.Tensor, sample: torch.Tensor, timestep: int) -> torch.Tensor:
    """
    Convert a predicted original sample (x0) into the quantity selected by `self.config.prediction_type`.

    The value computed per prediction type is:

    - `"epsilon"`: `(sample - alpha_t * pred_x0) / sigma_t`
    - `"sample"`: `pred_x0` is passed through unchanged
    - `"v_prediction"`: `alpha_t * pred_x0 + sigma_t * sample`

    When `self.config.thresholding` is truthy the result is additionally passed through
    `self._threshold_sample` before being returned.

    NOTE(review): only the `"epsilon"` branch is an actual noise estimate. The `"sample"` branch returns x0
    itself, and the `"v_prediction"` expression is not the algebraic inverse of
    `x0 = alpha_t * sample - sigma_t * v`. Thresholding a non-x0 quantity also looks unusual. Confirm the
    intended contract with the callers before relying on these branches.

    Args:
        pred_x0 (`torch.Tensor`): The predicted original sample (x0).
        sample (`torch.Tensor`): A current instance of a sample created by the diffusion process.
        timestep (`int`):
            The current discrete timestep in the diffusion chain. Only used to initialise the step index when
            `self.step_index` is still `None`.

    Returns:
        `torch.Tensor`: The converted prediction.

    Raises:
        ValueError: If `self.config.prediction_type` is not one of the three supported values.
    """
    # The step index may not have been set yet; derive it from the timestep on first use.
    if self.step_index is None:
        self._init_step_index(timestep)

    # Split the sigma of the current step into its (alpha_t, sigma_t) pair.
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])

    prediction_type = self.config.prediction_type
    if prediction_type == "epsilon":
        converted = (sample - alpha_t * pred_x0) / sigma_t
    elif prediction_type == "sample":
        converted = pred_x0
    elif prediction_type == "v_prediction":
        converted = alpha_t * pred_x0 + sigma_t * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`."
        )

    # Thresholding is applied to the converted value regardless of which branch produced it.
    return self._threshold_sample(converted) if self.config.thresholding else converted
590-
591530 def convert_model_output (
592531 self ,
593532 model_output : torch .Tensor ,
594533 * args ,
595534 sample : torch .Tensor = None ,
535+ predict_x0 : bool = True ,
536+ step_index : Optional [int ] = None ,
596537 ** kwargs ,
597538 ) -> torch .Tensor :
598539 r"""
@@ -622,11 +563,12 @@ def convert_model_output(
622563 "1.0.0" ,
623564 "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`" ,
624565 )
566+ step_index = step_index if step_index is not None else self .step_index
625567
626- sigma = self .sigmas [self . step_index ]
568+ sigma = self .sigmas [step_index ]
627569 alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma )
628570
629- if self . predict_x0 :
571+ if predict_x0 :
630572 if self .config .prediction_type == "epsilon" :
631573 x0_pred = (sample - sigma_t * model_output ) / alpha_t
632574 elif self .config .prediction_type == "sample" :
@@ -996,7 +938,7 @@ def step(
996938 self .step_index > 0 and self .step_index - 1 not in self .disable_corrector and self .last_sample is not None
997939 )
998940
999- model_output_convert = self .convert_model_output (model_output , sample = sample )
941+ model_output_convert = self .convert_model_output (model_output , sample = sample , predict_x0 = self . predict_x0 )
1000942 if use_corrector :
1001943 sample = self .multistep_uni_c_bh_update (
1002944 this_model_output = model_output_convert ,
0 commit comments