2121from bayesflow .utils .serialization import serialize , deserialize , serializable
2222
2323
24- @serializable
24+ @serializable ( "bayesflow.experimental" )
2525class DiffusionModel (InferenceNetwork ):
2626 """Diffusion Model as described in this overview paper [1].
2727
@@ -49,7 +49,7 @@ def __init__(
4949 * ,
5050 subnet : str | type = "mlp" ,
5151 integrate_kwargs : dict [str , any ] = None ,
52- noise_schedule : Literal ["edm" , "cosine" ] | type = "edm" ,
52+ noise_schedule : Literal ["edm" , "cosine" ] | dict | type = "edm" ,
5353 prediction_type : Literal ["velocity" , "noise" , "F" ] = "F" ,
5454 ** kwargs ,
5555 ):
@@ -69,8 +69,10 @@ def __init__(
6969 callable network. Default is "mlp".
7070 integrate_kwargs : dict[str, any], optional
7171 Additional keyword arguments for the integration process. Default is None.
72- noise_schedule : Literal['edm', 'cosine'] or type, optional
72+ noise_schedule : Literal['edm', 'cosine'], dict or type, optional
7373 The noise schedule used for the diffusion process. Can be "cosine" or "edm" or a custom noise schedule.
74+ You can also pass a dictionary with the configuration for the noise schedule, e.g.,
75+ {'type': cosine, 's_shift_cosine': 1.0}
7476 Default is "edm".
7577 prediction_type: Literal['velocity', 'noise', 'F'], optional
7678 The type of prediction used in the diffusion model. Can be "velocity", "noise" or "F" (EDM).
0 commit comments