@@ -248,7 +248,13 @@ def _set_state_dict_into_text_encoder(
 
 
 def compute_density_for_timestep_sampling(
-    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+    weighting_scheme: str,
+    batch_size: int,
+    logit_mean: float = None,
+    logit_std: float = None,
+    mode_scale: float = None,
+    device: torch.device = "cpu",
+    generator: Optional[torch.Generator] = None,
 ):
     """
     Compute the density for sampling the timesteps when doing SD3 training.
@@ -258,14 +264,13 @@ def compute_density_for_timestep_sampling(
     SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
     """
     if weighting_scheme == "logit_normal":
-        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
         u = torch.nn.functional.sigmoid(u)
     elif weighting_scheme == "mode":
-        u = torch.rand(size=(batch_size,), device="cpu")
+        u = torch.rand(size=(batch_size,), device=device, generator=generator)
         u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
     else:
-        u = torch.rand(size=(batch_size,), device="cpu")
+        u = torch.rand(size=(batch_size,), device=device, generator=generator)
     return u
 
 
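The new `device` and `generator` parameters let callers place the sampled tensor on a target device and make the sampling reproducible. A minimal usage sketch, assuming the function lives in `diffusers.training_utils` as the hunk context suggests; the seed and the logit-normal values (matching the SD3 paper's lognorm(0.00, 1.00)) are illustrative, not part of the diff:

    import torch
    from diffusers.training_utils import compute_density_for_timestep_sampling

    # Seeded generator: repeated calls with a freshly seeded generator
    # return identical samples. Its device must match `device`.
    generator = torch.Generator(device="cpu").manual_seed(42)

    u = compute_density_for_timestep_sampling(
        weighting_scheme="logit_normal",
        batch_size=4,
        logit_mean=0.0,
        logit_std=1.0,
        generator=generator,
    )

Since both new parameters default to the old hard-coded behavior (`device="cpu"`, no generator), existing training scripts that call the function without them are unaffected.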