def __init__(self, model: nn.Module, params: dict,
             noise_schedule: NoiseScheduler,
             model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
             guidance_scale: float = 0.0,
             null_labels_seq: jax.Array = None,
             autoencoder=None,
             image_size=256,
             autoenc_scale_reduction=8,
             autoenc_latent_channels=4,
             ):
    """Build a diffusion sampler around a trained denoising model.

    Args:
        model: module whose ``apply(params, ...)`` produces the denoiser output.
        params: parameter pytree passed to ``model.apply``.
        noise_schedule: supplies ``get_rates`` and ``transform_inputs`` per timestep.
        model_output_transform: maps the raw model output to ``(x_0, eps)``;
            defaults to epsilon prediction.
        guidance_scale: values > 0 enable classifier-free guidance with this weight.
        null_labels_seq: unconditional ("null") label sequence; required when
            ``guidance_scale > 0``. Presumably a text-embedding sequence — its
            shape must broadcast to the conditional labels (TODO confirm with caller).
        autoencoder: optional latent autoencoder; when set, sampling runs in
            latent space (see ``get_initial_samples``).
        image_size: output image side length in pixels.
        autoenc_scale_reduction: spatial downscale factor of the autoencoder.
        autoenc_latent_channels: channel count of the autoencoder latent space.

    Raises:
        ValueError: if guidance is requested without ``null_labels_seq``.
    """
    self.model = model
    self.noise_schedule = noise_schedule
    self.params = params
    self.model_output_transform = model_output_transform
    self.guidance_scale = guidance_scale
    self.image_size = image_size
    self.autoenc_scale_reduction = autoenc_scale_reduction
    self.autoencoder = autoencoder
    self.autoenc_latent_channels = autoenc_latent_channels

    if self.guidance_scale > 0:
        # Classifier-free guidance: run the conditional and unconditional
        # branches in a single batched forward pass and blend the outputs.
        if null_labels_seq is None:
            # raise, not assert: asserts are stripped under `python -O`.
            raise ValueError("Null labels sequence is required for classifier-free guidance")
        print("Using classifier-free guidance")

        def sample_model(x_t, t, *additional_inputs):
            # Duplicate the batch: first half conditional, second half unconditional.
            x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
            t_cat = jnp.concatenate([t] * 2, axis=0)
            rates_cat = self.noise_schedule.get_rates(t_cat)
            c_in_cat = self.model_output_transform.get_input_scale(rates_cat)

            text_labels_seq, = additional_inputs
            text_labels_seq = jnp.concatenate(
                [text_labels_seq,
                 jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)],
                axis=0)
            model_output = self.model.apply(
                self.params,
                *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat),
                text_labels_seq)
            # Split back and apply the guidance blend:
            # uncond + scale * (cond - uncond).
            model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
            model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)

            x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
            return x_0, eps, model_output
    else:
        # Plain (unguided) sampling: a single forward pass, forwarding any
        # extra conditioning inputs straight to the model.
        def sample_model(x_t, t, *additional_inputs):
            rates = self.noise_schedule.get_rates(t)
            c_in = self.model_output_transform.get_input_scale(rates)
            model_output = self.model.apply(
                self.params,
                *self.noise_schedule.transform_inputs(x_t * c_in, t),
                *additional_inputs)
            x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
            return x_0, eps, model_output

    self.sample_model = jax.jit(sample_model)
3371
# Used to sample from the diffusion model
def sample_step(self, current_samples: jnp.ndarray, current_step,
                model_conditioning_inputs, next_step=None,
                state: MarkovState = None) -> tuple[jnp.ndarray, MarkovState]:
    """Run one reverse-diffusion step.

    Args:
        current_samples: batch of noisy samples x_t.
        current_step: scalar timestep t, broadcast over the batch.
        model_conditioning_inputs: tuple of extra inputs forwarded to the model.
        next_step: scalar target timestep; broadcast over the batch when given.
        state: sampler Markov state threaded through the step.

    Returns:
        ``(new_samples, state)`` from ``take_next_step``.
    """
    step_ones = jnp.ones((current_samples.shape[0],), dtype=jnp.int32)
    current_step = step_ones * current_step
    # Guard: next_step defaults to None; broadcasting None would crash,
    # so pass it through untouched and let take_next_step use its default.
    if next_step is not None:
        next_step = step_ones * next_step
    pred_images, pred_noise, _ = self.sample_model(current_samples, current_step,
                                                   *model_conditioning_inputs)
    # plotImages(pred_images)
    # pred_images = clip_images(pred_images)
    new_samples, state = self.take_next_step(
        current_samples=current_samples, reconstructed_samples=pred_images,
        pred_noise=pred_noise, current_step=current_step, next_step=next_step,
        state=state, model_conditioning_inputs=model_conditioning_inputs)
    return new_samples, state
4686
47- def take_next_step (self , current_samples , reconstructed_samples ,
87+ def take_next_step (self , current_samples , reconstructed_samples , model_conditioning_inputs ,
4888 pred_noise , current_step , state :RandomMarkovState , next_step = 1 ) -> tuple [jnp .ndarray , RandomMarkovState ]:
4989 # estimate the q(x_{t-1} | x_t, x_0).
5090 # pred_images is x_0, noisy_images is x_t, steps is t
def get_steps(self, start_step, end_step, diffusion_steps):
    """Return the timestep schedule as int16 values running from
    ``start_step`` down to ``end_step`` over ``diffusion_steps`` entries."""
    ascending = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)
    return ascending[::-1]
64104
def get_initial_samples(self, num_images, rngs: jax.random.PRNGKey, start_step):
    """Draw the starting Gaussian noise for the reverse process.

    When an autoencoder is attached, sampling happens in latent space, so
    the spatial side shrinks by ``autoenc_scale_reduction`` and the channel
    count becomes ``autoenc_latent_channels``; otherwise RGB pixel space.
    """
    scaled_step = self.scale_steps(start_step)
    alpha_n, sigma_n = self.noise_schedule.get_rates(scaled_step)
    # Scale unit noise to the marginal std at the starting timestep.
    variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
    if self.autoencoder is None:
        side, channels = self.image_size, 3
    else:
        side = self.image_size // self.autoenc_scale_reduction
        channels = self.autoenc_latent_channels
    shape = (num_images, side, side, channels)
    return jax.random.normal(rngs, shape) * variance
70115
71116 def generate_images (self ,
72117 num_images = 16 ,
@@ -75,18 +120,23 @@ def generate_images(self,
75120 end_step :int = 0 ,
76121 steps_override = None ,
77122 priors = None ,
78- rngstate :RandomMarkovState = RandomMarkovState (jax .random .PRNGKey (42 ))) -> jnp .ndarray :
123+ rngstate :RandomMarkovState = RandomMarkovState (jax .random .PRNGKey (42 )),
124+ model_conditioning_inputs :tuple = ()
125+ ) -> jnp .ndarray :
79126 if priors is None :
80127 rngstate , newrngs = rngstate .get_random_key ()
81128 samples = self .get_initial_samples (num_images , newrngs , start_step )
82129 else :
83130 print ("Using priors" )
131+ if self .autoencoder is not None :
132+ priors = self .autoencoder .encode (priors )
84133 samples = priors
85134
86- @jax .jit
135+ # @jax.jit
87136 def sample_step (state :RandomMarkovState , samples , current_step , next_step ):
88137 samples , state = self .sample_step (current_samples = samples ,
89138 current_step = current_step ,
139+ model_conditioning_inputs = model_conditioning_inputs ,
90140 state = state , next_step = next_step )
91141 return samples , state
92142
@@ -108,6 +158,8 @@ def sample_step(state:RandomMarkovState, samples, current_step, next_step):
108158 else :
109159 # print("last step")
110160 step_ones = jnp .ones ((num_images , ), dtype = jnp .int32 )
111- samples , _ , _ = self .sample_model (samples , current_step * step_ones )
161+ samples , _ , _ = self .sample_model (samples , current_step * step_ones , * model_conditioning_inputs )
162+ if self .autoencoder is not None :
163+ samples = self .autoencoder .decode (samples )
112164 samples = clip_images (samples )
113- return samples
165+ return samples
0 commit comments