@@ -374,8 +374,8 @@ def __init__(
374374 subnet : str | type = "mlp" ,
375375 integrate_kwargs : dict [str , any ] = None ,
376376 subnet_kwargs : dict [str , any ] = None ,
377- noise_schedule : str | NoiseSchedule = "cosine " ,
378- prediction_type : str = "velocity " ,
377+ noise_schedule : str | NoiseSchedule = "edm " ,
378+ prediction_type : str = "F " ,
379379 ** kwargs ,
380380 ):
381381 """
@@ -398,10 +398,10 @@ def __init__(
398398 Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
399399 noise_schedule : str or NoiseSchedule, optional
400400 The noise schedule used for the diffusion process. Can be "linear", "cosine", or "edm".
401- Default is "cosine ".
401+ Default is "edm ".
402402 prediction_type: str, optional
403403 The type of prediction used in the diffusion model. Can be "velocity", "noise" or "F" (EDM).
404- Default is "velocity ".
404+ Default is "F ".
405405 **kwargs
406406 Additional keyword arguments passed to the subnet and other components.
407407 """
@@ -425,10 +425,6 @@ def __init__(
425425 if prediction_type not in ["noise" , "velocity" , "F" ]: # F is EDM
426426 raise ValueError (f"Unknown prediction type: { prediction_type } " )
427427 self ._prediction_type = prediction_type
428- if noise_schedule .name == "edm_noise_schedule" and prediction_type != "F" :
429- warnings .warn (
430- "EDM noise schedule is build for F-prediction. Consider using F-prediction instead." ,
431- )
432428 self ._loss_type = kwargs .get ("loss_type" , "noise" )
433429 if self ._loss_type not in ["noise" , "velocity" , "F" ]:
434430 raise ValueError (f"Unknown loss type: { self ._loss_type } " )
0 commit comments