Skip to content

Commit a98700f

Browse files
support wan-fun-inp generating
1 parent 5418ca7 commit a98700f

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

diffsynth/pipelines/wan_video.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -163,16 +163,22 @@ def encode_prompt(self, prompt, positive=True):
163163
return {"context": prompt_emb}
164164

165165

166-
def encode_image(self, image, num_frames, height, width):
166+
def encode_image(self, image, end_image, num_frames, height, width):
167167
image = self.preprocess_image(image.resize((width, height))).to(self.device)
168168
clip_context = self.image_encoder.encode_image([image])
169169
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
170170
msk[:, 1:] = 0
171+
if end_image is not None:
172+
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
173+
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
174+
msk[:, -1:] = 1
175+
else:
176+
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
177+
171178
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
172179
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
173180
msk = msk.transpose(1, 2)[0]
174181

175-
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
176182
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
177183
y = torch.concat([msk, y])
178184
y = y.unsqueeze(0)
@@ -212,6 +218,7 @@ def __call__(
212218
prompt,
213219
negative_prompt="",
214220
input_image=None,
221+
end_image=None,
215222
input_video=None,
216223
denoising_strength=1.0,
217224
seed=None,
@@ -263,7 +270,7 @@ def __call__(
263270
# Encode image
264271
if input_image is not None and self.image_encoder is not None:
265272
self.load_models_to_device(["image_encoder", "vae"])
266-
image_emb = self.encode_image(input_image, num_frames, height, width)
273+
image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
267274
else:
268275
image_emb = {}
269276

0 commit comments

Comments
 (0)