Skip to content

Commit aa2598a

Browse files
authored
fix image size compatibility (#16)
1 parent 4dedfce commit aa2598a

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

generate_video.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from skyreels_v2_infer.modules import download_model
1212
from skyreels_v2_infer.pipelines import Image2VideoPipeline
1313
from skyreels_v2_infer.pipelines import PromptEnhancer
14+
from skyreels_v2_infer.pipelines import resizecrop
1415
from skyreels_v2_infer.pipelines import Text2VideoPipeline
1516

1617
MODEL_ID_CONFIG = {
@@ -109,6 +110,11 @@
109110
pipe = Image2VideoPipeline(
110111
model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
111112
)
113+
args.image = load_image(args.image)
114+
image_width, image_height = args.image.size
115+
if image_height > image_width:
116+
height, width = width, height
117+
args.image = resizecrop(args.image, height, width)
112118

113119
prompt_input = args.prompt
114120
if args.prompt_enhancer and image is not None:
@@ -128,7 +134,7 @@
128134
}
129135

130136
if image is not None:
131-
kwargs["image"] = load_image(args.image).convert("RGB")
137+
kwargs["image"] = args.image.convert("RGB")
132138

133139
save_dir = os.path.join("result", args.outdir)
134140
os.makedirs(save_dir, exist_ok=True)
skyreels_v2_infer/pipelines/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .diffusion_forcing_pipeline import DiffusionForcingPipeline
2-
from .text2video_pipeline import Text2VideoPipeline
32
from .image2video_pipeline import Image2VideoPipeline
3+
from .image2video_pipeline import resizecrop
44
from .prompt_enhancer import PromptEnhancer
5+
from .text2video_pipeline import Text2VideoPipeline

skyreels_v2_infer/pipelines/image2video_pipeline.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from diffusers.image_processor import PipelineImageInput
99
from diffusers.video_processor import VideoProcessor
10+
from PIL import Image
1011
from tqdm import tqdm
1112

1213
from ..modules import get_image_encoder
@@ -16,6 +17,24 @@
1617
from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
1718

1819

20+
def resizecrop(image: Image.Image, th, tw):
    """Center-crop *image* to the aspect ratio of a ``th`` x ``tw`` target.

    Despite the name, this function does not resample: it cuts out the
    largest centered window whose height:width ratio matches ``th:tw`` and
    returns it at the source resolution. Scaling to the exact target size is
    left to the caller.

    Args:
        image: Source image; only its ``size`` attribute and ``crop`` method
            are used.
        th: Target height (must be positive).
        tw: Target width (must be positive).

    Returns:
        ``image`` itself when it already has size ``(tw, th)``; otherwise the
        result of ``image.crop`` on a centered box.

    Raises:
        ValueError: If ``th`` or ``tw`` is not positive.
    """
    if th <= 0 or tw <= 0:
        raise ValueError(f"target size must be positive, got th={th}, tw={tw}")

    w, h = image.size
    if w == tw and h == th:
        return image

    # Cross-multiplied form of ``h / w > th / tw``: same answer for positive
    # sizes, but no division, so a zero-width source cannot raise.
    if h * tw > th * w:
        # Source is relatively taller than the target: keep the full width
        # and trim rows from the top and bottom.
        new_w = int(w)
        new_h = int(new_w * th / tw)
    else:
        # Source is relatively wider (or equal): keep the full height and
        # trim columns from the left and right.
        new_h = int(h)
        new_w = int(new_h * tw / th)

    # Integer box: with ``/ 2`` the edges become ``x.5`` whenever the margin
    # is odd, and the cropped size then depends on how the imaging library
    # rounds each edge. Deriving right/bottom from left/top guarantees the
    # window is exactly new_w x new_h.
    left = (w - new_w) // 2
    top = (h - new_h) // 2
    return image.crop((left, top, left + new_w, top + new_h))
36+
37+
1938
class Image2VideoPipeline:
2039
def __init__(
2140
self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False

0 commit comments

Comments
 (0)