@@ -45,18 +45,6 @@ def __init__(self, name: str, variance_type: str):
         self.variance_type = variance_type  # 'exploding' or 'preserving'
         self._log_snr_min = -15  # should be set in the subclasses
         self._log_snr_max = 15  # should be set in the subclasses
-        self.sigma_data = 1.0
-
-    @property
-    def scale_base_distribution(self):
-        """Get the scale of the base distribution."""
-        if self.variance_type == "preserving":
-            return 1.0
-        elif self.variance_type == "exploding":
-            # e.g., EDM is a variance exploding schedule
-            return ops.sqrt(ops.exp(-self._log_snr_min))
-        else:
-            raise ValueError(f"Unknown variance type: {self.variance_type}")

     @abstractmethod
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
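A note on the deleted `scale_base_distribution` property: for a variance-exploding schedule (where `alpha_t = 1`, see `get_alpha_sigma` below), the log-SNR is `-2 * log(sigma)`, so `sqrt(exp(-log_snr_min))` is exactly the maximal noise scale. A minimal check of that identity, assuming an EDM-style `sigma_max = 80` (not taken from this diff):

```python
# Hedged check: sqrt(exp(-log_snr_min)) == sigma_max for variance exploding,
# since alpha = 1 implies log SNR = -2 * log(sigma).
import math

sigma_max = 80.0                           # assumed EDM-style value
log_snr_min = -2.0 * math.log(sigma_max)
print(math.sqrt(math.exp(-log_snr_min)))  # 80.0 == sigma_max
```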
@@ -106,8 +94,8 @@ def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
         """
         if self.variance_type == "preserving":
             # variance preserving schedule
-            alpha_t = keras.ops.sqrt(keras.ops.sigmoid(log_snr_t))
-            sigma_t = keras.ops.sqrt(keras.ops.sigmoid(-log_snr_t))
+            alpha_t = ops.sqrt(ops.sigmoid(log_snr_t))
+            sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t))
         elif self.variance_type == "exploding":
             # variance exploding schedule
             alpha_t = ops.ones_like(log_snr_t)
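A quick sanity check of the variance-preserving branch: since `sigmoid(x) + sigmoid(-x) = 1`, the coefficients satisfy `alpha_t**2 + sigma_t**2 = 1` at every noise level. A minimal sketch:

```python
# Verify alpha_t^2 + sigma_t^2 = 1 for the variance-preserving schedule.
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

for log_snr in (-10.0, 0.0, 10.0):
    alpha_t = math.sqrt(sigmoid(log_snr))
    sigma_t = math.sqrt(sigmoid(-log_snr))
    print(alpha_t**2 + sigma_t**2)  # 1.0 up to floating-point error
```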
@@ -271,6 +259,7 @@ def from_config(cls, config, custom_objects=None):
 class EDMNoiseSchedule(NoiseSchedule):
     """EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
     This should be used with the F-prediction type in the diffusion model.
+    Since the schedule is variance exploding, the base distribution is a Gaussian with scale 'sigma_max'.

     [1] Elucidating the Design Space of Diffusion-Based Generative Models: Karras et al. (2022)
     """
@@ -301,7 +290,7 @@ def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
             loc = -2 * self.p_mean
             scale = 2 * self.p_std
             snr = -(loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2))
-            snr = keras.ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
+            snr = ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
         else:  # sampling
             sigma_min_rho = self.sigma_min ** (1 / self.rho)
             sigma_max_rho = self.sigma_max ** (1 / self.rho)
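For context on the two branches: `ops.erfinv(2 * t - 1) * math.sqrt(2)` is the standard-normal quantile of `t`, so uniform training times map to a normally distributed log-SNR, while the sampling branch prepares the Karras rho schedule, which interpolates in `sigma**(1/rho)` space. A sketch of the sampling grid, assuming the EDM defaults `rho=7`, `sigma_min=0.002`, `sigma_max=80` and the standard interpolation (the rest of this branch is cut off in the hunk):

```python
# Hedged sketch of the Karras sampling schedule assumed to follow below:
# sigma(t) interpolates linearly in sigma^(1/rho) space, sigma_max -> sigma_min.
import numpy as np

rho, sigma_min, sigma_max = 7.0, 0.002, 80.0   # assumed EDM defaults
t = np.linspace(0.0, 1.0, 5)
sigma = (sigma_max ** (1 / rho) + t * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
log_snr = -2.0 * np.log(sigma)                 # variance exploding: log SNR = -2 log sigma
print(sigma)                                   # decays from 80.0 down to 0.002
```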
@@ -375,7 +364,7 @@ class DiffusionModel(InferenceNetwork):

     INTEGRATE_DEFAULT_CONFIG = {
         "method": "euler",  # or euler_maruyama
-        "steps": 100,
+        "steps": 250,
     }

     def __init__(
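These defaults are merged with user overrides via dict union in `__init__` (see the hunk below), so individual keys can be changed without restating the whole config. A minimal illustration of the pattern:

```python
# The `defaults | overrides` pattern used for INTEGRATE_DEFAULT_CONFIG:
# keys from the right-hand dict win.
defaults = {"method": "euler", "steps": 250}
overrides = {"steps": 500}
print(defaults | overrides)  # {'method': 'euler', 'steps': 500}
```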
@@ -444,9 +433,7 @@ def __init__(
         self._clip_max = 5.0

         # latent distribution (not configurable)
-        self.base_distribution = bf.distributions.DiagonalNormal(
-            mean=0.0, std=self.noise_schedule.scale_base_distribution
-        )
+        self.base_distribution = bf.distributions.DiagonalNormal()
         self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
         self.seed_generator = keras.random.SeedGenerator()

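With the scale argument gone, the base distribution is a plain `DiagonalNormal()`, which presumably defaults to a standard normal. Under that assumption, the `log_prob` consumed by `_forward`/`_inverse` below reduces to the familiar closed form; a minimal sketch:

```python
# Hedged sketch, assuming DiagonalNormal() defaults to mean 0 and std 1:
# log N(z; 0, I) = -0.5 * (d * log(2*pi) + ||z||^2)
import math
import numpy as np

def standard_normal_log_prob(z):
    d = z.shape[-1]
    return -0.5 * (d * math.log(2.0 * math.pi) + np.sum(z**2, axis=-1))

print(standard_normal_log_prob(np.zeros((2, 3))))  # maximal at the origin
```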
@@ -521,7 +508,7 @@ def convert_prediction_to_x(
             x = (z + sigma_t**2 * pred) / alpha_t

         if clip_x:
-            x = keras.ops.clip(x, self._clip_min, self._clip_max)
+            x = ops.clip(x, self._clip_min, self._clip_max)
         return x

     def velocity(
@@ -535,13 +522,13 @@ def velocity(
     ) -> Tensor:
         # calculate the current noise level and transform into correct shape
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
-        log_snr_t = keras.ops.broadcast_to(log_snr_t, keras.ops.shape(xz)[:-1] + (1,))
+        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
         alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t, training=training)

         if conditions is None:
-            xtc = keras.ops.concatenate([xz, log_snr_t], axis=-1)
+            xtc = ops.concatenate([xz, log_snr_t], axis=-1)
         else:
-            xtc = keras.ops.concatenate([xz, log_snr_t, conditions], axis=-1)
+            xtc = ops.concatenate([xz, log_snr_t, conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)

         x_pred = self.convert_prediction_to_x(
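The network input built here is simply a last-axis concatenation of the noisy sample, a scalar log-SNR channel, and optional conditions. A shape-only sketch with made-up dimensions:

```python
# Shape-only illustration of the xtc concatenation (dimensions are made up).
import numpy as np

batch, d, c = 4, 2, 3
xz = np.zeros((batch, d))
log_snr_t = np.zeros((batch, 1))
conditions = np.zeros((batch, c))
xtc = np.concatenate([xz, log_snr_t, conditions], axis=-1)
print(xtc.shape)  # (4, 6) == (batch, d + 1 + c)
```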
@@ -570,7 +557,7 @@ def compute_diffusion_term(
     ) -> Tensor:
         # calculate the current noise level and transform into correct shape
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
-        log_snr_t = keras.ops.broadcast_to(log_snr_t, keras.ops.shape(xz)[:-1] + (1,))
+        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
         g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t)
         return ops.sqrt(g_squared)

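`compute_diffusion_term` returns `g(t) = sqrt(g_squared)`, the coefficient of the Brownian increment in an SDE step. A hedged sketch of how an Euler-Maruyama integrator (the `euler_maruyama` method in the config above) would presumably consume it:

```python
# Hedged Euler-Maruyama step: x <- x + f*dt + g*sqrt(|dt|)*eps, eps ~ N(0, I).
import numpy as np

def euler_maruyama_step(x, drift, g, dt, rng):
    eps = rng.standard_normal(x.shape)          # fresh Gaussian noise per step
    return x + drift * dt + g * np.sqrt(abs(dt)) * eps

rng = np.random.default_rng(0)
print(euler_maruyama_step(np.zeros(2), drift=np.ones(2), g=0.5, dt=-0.01, rng=rng))
```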
@@ -587,7 +574,7 @@ def f(x):

         v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)

-        return v, keras.ops.expand_dims(trace, axis=-1)
+        return v, ops.expand_dims(trace, axis=-1)

     def _forward(
         self,
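`jacobian_trace` presumably implements a stochastic (Hutchinson-style) estimate of `tr(dv/dx)`, which is what keeps density tracking affordable in high dimension. A self-contained sketch of that estimator, checked against a known Jacobian:

```python
# Hutchinson's estimator: tr(J) = E[eps^T J eps] for Rademacher probes eps.
import numpy as np

def hutchinson_trace(jvp, dim, num_probes=1000, seed=0):
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(num_probes):
        eps = rng.choice([-1.0, 1.0], size=dim)  # Rademacher probe vector
        total += eps @ jvp(eps)                  # eps^T (J eps)
    return total / num_probes

A = np.diag([1.0, 2.0, 3.0])                     # known Jacobian, trace 6
print(hutchinson_trace(lambda e: A @ e, dim=3))  # approximately 6.0
```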
@@ -616,7 +603,7 @@ def deltas(time, xz):

         state = {
             "xz": x,
-            "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x)),
+            "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)),
         }
         state = integrate(
             deltas,
@@ -625,7 +612,7 @@ def deltas(time, xz):
         )

         z = state["xz"]
-        log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)
+        log_density = self.base_distribution.log_prob(z) + ops.squeeze(state["trace"], axis=-1)

         return z, log_density

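The accumulated trace implements the continuous change-of-variables formula: the forward pass adds the integrated trace to the base log-density, while `_inverse` below subtracts it. In the sign conventions suggested by this diff (a minimal statement, not taken verbatim from the source):

```latex
\log p_x(x) \;=\; \log p_z(z) \;+\; \int \operatorname{tr}\!\left(\frac{\partial v(x_t, t)}{\partial x_t}\right) \mathrm{d}t
```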
@@ -669,12 +656,12 @@ def deltas(time, xz):

         state = {
             "xz": z,
-            "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z)),
+            "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)),
         }
         state = integrate(deltas, state, **integrate_kwargs)

         x = state["xz"]
-        log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)
+        log_density = self.base_distribution.log_prob(z) - ops.squeeze(state["trace"], axis=-1)

         return x, log_density

@@ -723,17 +710,17 @@ def compute_metrics(
         training = stage == "training"
         noise_schedule_training_stage = stage == "training" or stage == "validation"
         if not self.built:
-            xz_shape = keras.ops.shape(x)
-            conditions_shape = None if conditions is None else keras.ops.shape(conditions)
+            xz_shape = ops.shape(x)
+            conditions_shape = None if conditions is None else ops.shape(conditions)
             self.build(xz_shape, conditions_shape)

         # sample training diffusion time as low discrepancy sequence to decrease variance
         # t_i = \mod(u_0 + i/k, 1)
         u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x), seed=self.seed_generator)
-        i = ops.arange(0, keras.ops.shape(x)[0], dtype=ops.dtype(x))  # tensor of indices
-        t = (u0 + i / ops.cast(keras.ops.shape(x)[0], dtype=ops.dtype(x))) % 1
-        # i = keras.random.randint((keras.ops.shape(x)[0],), minval=0, maxval=self._timesteps)
-        # t = keras.ops.cast(i, keras.ops.dtype(x)) / keras.ops.cast(self._timesteps, keras.ops.dtype(x))
+        i = ops.arange(0, ops.shape(x)[0], dtype=ops.dtype(x))  # tensor of indices
+        t = (u0 + i / ops.cast(ops.shape(x)[0], dtype=ops.dtype(x))) % 1
+        # i = keras.random.randint((ops.shape(x)[0],), minval=0, maxval=self._timesteps)
+        # t = ops.cast(i, ops.dtype(x)) / ops.cast(self._timesteps, ops.dtype(x))

         # calculate the noise level
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t, training=noise_schedule_training_stage), x)
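The low-discrepancy construction `t_i = (u0 + i/k) mod 1` stratifies `[0, 1)`: one shared uniform offset plus evenly spaced shifts puts each batch element in its own sub-interval, which reduces the variance of the loss relative to i.i.d. uniform times. A minimal NumPy sketch:

```python
# Stratified diffusion times: one per interval [i/k, (i+1)/k), jittered by u0.
import numpy as np

k = 8                                  # stands in for the batch size
u0 = np.random.uniform()               # single shared random offset
i = np.arange(k, dtype=np.float64)
t = (u0 + i / k) % 1.0
print(np.sort(t))                      # evenly spaced points across [0, 1)
```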
@@ -749,9 +736,9 @@ def compute_metrics(

         # calculate output of the network
         if conditions is None:
-            xtc = keras.ops.concatenate([diffused_x, log_snr_t], axis=-1)
+            xtc = ops.concatenate([diffused_x, log_snr_t], axis=-1)
         else:
-            xtc = keras.ops.concatenate([diffused_x, log_snr_t, conditions], axis=-1)
+            xtc = ops.concatenate([diffused_x, log_snr_t, conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)

         x_pred = self.convert_prediction_to_x(