6
6
7
7
8
8
class ScoreSdeVePipeline (DiffusionPipeline ):
9
- def __init__ (self , model , scheduler ):
9
+ def __init__ (self , unet , scheduler ):
10
10
super ().__init__ ()
11
- self .register_modules (model = model , scheduler = scheduler )
11
+ self .register_modules (unet = unet , scheduler = scheduler )
12
12
13
13
@torch .no_grad ()
14
14
def __call__ (self , batch_size = 1 , num_inference_steps = 2000 , generator = None , torch_device = None , output_type = "pil" ):
15
+
15
16
if torch_device is None :
16
17
torch_device = "cuda" if torch .cuda .is_available () else "cpu"
17
18
18
- img_size = self .model .config .sample_size
19
+ img_size = self .unet .config .sample_size
19
20
shape = (batch_size , 3 , img_size , img_size )
20
21
21
- model = self .model .to (torch_device )
22
+ model = self .unet .to (torch_device )
22
23
23
24
sample = torch .randn (* shape ) * self .scheduler .config .sigma_max
24
25
sample = sample .to (torch_device )
@@ -31,7 +32,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch
31
32
32
33
# correction step
33
34
for _ in range (self .scheduler .correct_steps ):
34
- model_output = self .model (sample , sigma_t )["sample" ]
35
+ model_output = self .unet (sample , sigma_t )["sample" ]
35
36
sample = self .scheduler .step_correct (model_output , sample )["prev_sample" ]
36
37
37
38
# prediction step
@@ -40,7 +41,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch
40
41
41
42
sample , sample_mean = output ["prev_sample" ], output ["prev_sample_mean" ]
42
43
43
- sample = sample .clamp (0 , 1 )
44
+ sample = sample_mean .clamp (0 , 1 )
44
45
sample = sample .cpu ().permute (0 , 2 , 3 , 1 ).numpy ()
45
46
if output_type == "pil" :
46
47
sample = self .numpy_to_pil (sample )
0 commit comments