Skip to content

Commit 7c0a861

Browse files
committed
Add torch_device to the VE pipeline
1 parent a73ae3e commit 7c0a861

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,23 @@ def __init__(self, model, scheduler):
1111
self.register_modules(model=model, scheduler=scheduler)
1212

1313
@torch.no_grad()
14-
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
15-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
14+
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
15+
if torch_device is None:
16+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
1617

1718
img_size = self.model.config.sample_size
18-
shape = (1, 3, img_size, img_size)
19+
shape = (batch_size, 3, img_size, img_size)
1920

20-
model = self.model.to(device)
21+
model = self.model.to(torch_device)
2122

2223
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
23-
sample = sample.to(device)
24+
sample = sample.to(torch_device)
2425

2526
self.scheduler.set_timesteps(num_inference_steps)
2627
self.scheduler.set_sigmas(num_inference_steps)
2728

2829
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
29-
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
30+
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
3031

3132
# correction step
3233
for _ in range(self.scheduler.correct_steps):

0 commit comments

Comments
 (0)