Skip to content

Commit 5311f56

Browse files
Final fixes (#118)
final fixes before release
1 parent 3b7f514 commit 5311f56

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,20 @@
66

77

88
class ScoreSdeVePipeline(DiffusionPipeline):
9-
def __init__(self, model, scheduler):
9+
def __init__(self, unet, scheduler):
1010
super().__init__()
11-
self.register_modules(model=model, scheduler=scheduler)
11+
self.register_modules(unet=unet, scheduler=scheduler)
1212

1313
@torch.no_grad()
1414
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
15+
1516
if torch_device is None:
1617
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
1718

18-
img_size = self.model.config.sample_size
19+
img_size = self.unet.config.sample_size
1920
shape = (batch_size, 3, img_size, img_size)
2021

21-
model = self.model.to(torch_device)
22+
model = self.unet.to(torch_device)
2223

2324
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
2425
sample = sample.to(torch_device)
@@ -31,7 +32,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch
3132

3233
# correction step
3334
for _ in range(self.scheduler.correct_steps):
34-
model_output = self.model(sample, sigma_t)["sample"]
35+
model_output = self.unet(sample, sigma_t)["sample"]
3536
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
3637

3738
# prediction step
@@ -40,7 +41,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch
4041

4142
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
4243

43-
sample = sample.clamp(0, 1)
44+
sample = sample_mean.clamp(0, 1)
4445
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
4546
if output_type == "pil":
4647
sample = self.numpy_to_pil(sample)

tests/test_modeling_utils.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -848,15 +848,12 @@ def test_ldm_text2img_fast(self):
848848

849849
@slow
850850
def test_score_sde_ve_pipeline(self):
851-
model = UNet2DModel.from_pretrained("google/ncsnpp-church-256")
851+
model_id = "google/ncsnpp-church-256"
852+
model = UNet2DModel.from_pretrained(model_id)
852853

853-
torch.manual_seed(0)
854-
if torch.cuda.is_available():
855-
torch.cuda.manual_seed_all(0)
856-
857-
scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
854+
scheduler = ScoreSdeVeScheduler.from_config(model_id)
858855

859-
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
856+
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
860857

861858
torch.manual_seed(0)
862859
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]

0 commit comments

Comments
 (0)