@@ -163,16 +163,22 @@ def encode_prompt(self, prompt, positive=True):
163163 return {"context" : prompt_emb }
164164
165165
166- def encode_image (self , image , num_frames , height , width ):
166+ def encode_image (self , image , end_image , num_frames , height , width ):
167167 image = self .preprocess_image (image .resize ((width , height ))).to (self .device )
168168 clip_context = self .image_encoder .encode_image ([image ])
169169 msk = torch .ones (1 , num_frames , height // 8 , width // 8 , device = self .device )
170170 msk [:, 1 :] = 0
171+ if end_image is not None :
172+ end_image = self .preprocess_image (end_image .resize ((width , height ))).to (self .device )
173+ vae_input = torch .concat ([image .transpose (0 ,1 ), torch .zeros (3 , num_frames - 2 , height , width ).to (image .device ), end_image .transpose (0 ,1 )],dim = 1 )
174+ msk [:, - 1 :] = 1
175+ else :
176+ vae_input = torch .concat ([image .transpose (0 , 1 ), torch .zeros (3 , num_frames - 1 , height , width ).to (image .device )], dim = 1 )
177+
171178 msk = torch .concat ([torch .repeat_interleave (msk [:, 0 :1 ], repeats = 4 , dim = 1 ), msk [:, 1 :]], dim = 1 )
172179 msk = msk .view (1 , msk .shape [1 ] // 4 , 4 , height // 8 , width // 8 )
173180 msk = msk .transpose (1 , 2 )[0 ]
174181
175- vae_input = torch .concat ([image .transpose (0 , 1 ), torch .zeros (3 , num_frames - 1 , height , width ).to (image .device )], dim = 1 )
176182 y = self .vae .encode ([vae_input .to (dtype = self .torch_dtype , device = self .device )], device = self .device )[0 ]
177183 y = torch .concat ([msk , y ])
178184 y = y .unsqueeze (0 )
@@ -212,6 +218,7 @@ def __call__(
212218 prompt ,
213219 negative_prompt = "" ,
214220 input_image = None ,
221+ end_image = None ,
215222 input_video = None ,
216223 denoising_strength = 1.0 ,
217224 seed = None ,
@@ -263,7 +270,7 @@ def __call__(
263270 # Encode image
264271 if input_image is not None and self .image_encoder is not None :
265272 self .load_models_to_device (["image_encoder" , "vae" ])
266- image_emb = self .encode_image (input_image , num_frames , height , width )
273+ image_emb = self .encode_image (input_image , end_image , num_frames , height , width )
267274 else :
268275 image_emb = {}
269276
0 commit comments