@@ -11,22 +11,23 @@ def __init__(self, model, scheduler):
1111 self .register_modules (model = model , scheduler = scheduler )
1212
1313 @torch .no_grad ()
14- def __call__ (self , num_inference_steps = 2000 , generator = None , output_type = "pil" ):
15- device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
14+ def __call__ (self , batch_size = 1 , num_inference_steps = 2000 , generator = None , torch_device = None , output_type = "pil" ):
15+ if torch_device is None :
16+ torch_device = "cuda" if torch .cuda .is_available () else "cpu"
1617
1718 img_size = self .model .config .sample_size
18- shape = (1 , 3 , img_size , img_size )
19+ shape = (batch_size , 3 , img_size , img_size )
1920
20- model = self .model .to (device )
21+ model = self .model .to (torch_device )
2122
2223 sample = torch .randn (* shape ) * self .scheduler .config .sigma_max
23- sample = sample .to (device )
24+ sample = sample .to (torch_device )
2425
2526 self .scheduler .set_timesteps (num_inference_steps )
2627 self .scheduler .set_sigmas (num_inference_steps )
2728
2829 for i , t in tqdm (enumerate (self .scheduler .timesteps )):
29- sigma_t = self .scheduler .sigmas [i ] * torch .ones (shape [0 ], device = device )
30+ sigma_t = self .scheduler .sigmas [i ] * torch .ones (shape [0 ], device = torch_device )
3031
3132 # correction step
3233 for _ in range (self .scheduler .correct_steps ):
0 commit comments