@@ -11,22 +11,23 @@ def __init__(self, model, scheduler):
11
11
self .register_modules (model = model , scheduler = scheduler )
12
12
13
13
@torch .no_grad ()
14
- def __call__ (self , num_inference_steps = 2000 , generator = None , output_type = "pil" ):
15
- device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
14
+ def __call__ (self , batch_size = 1 , num_inference_steps = 2000 , generator = None , torch_device = None , output_type = "pil" ):
15
+ if torch_device is None :
16
+ torch_device = "cuda" if torch .cuda .is_available () else "cpu"
16
17
17
18
img_size = self .model .config .sample_size
18
- shape = (1 , 3 , img_size , img_size )
19
+ shape = (batch_size , 3 , img_size , img_size )
19
20
20
- model = self .model .to (device )
21
+ model = self .model .to (torch_device )
21
22
22
23
sample = torch .randn (* shape ) * self .scheduler .config .sigma_max
23
- sample = sample .to (device )
24
+ sample = sample .to (torch_device )
24
25
25
26
self .scheduler .set_timesteps (num_inference_steps )
26
27
self .scheduler .set_sigmas (num_inference_steps )
27
28
28
29
for i , t in tqdm (enumerate (self .scheduler .timesteps )):
29
- sigma_t = self .scheduler .sigmas [i ] * torch .ones (shape [0 ], device = device )
30
+ sigma_t = self .scheduler .sigmas [i ] * torch .ones (shape [0 ], device = torch_device )
30
31
31
32
# correction step
32
33
for _ in range (self .scheduler .correct_steps ):
0 commit comments