|
11 | 11 | from skyreels_v2_infer import DiffusionForcingPipeline |
12 | 12 | from skyreels_v2_infer.modules import download_model |
13 | 13 | from skyreels_v2_infer.pipelines import PromptEnhancer |
14 | | -from skyreels_v2_infer.pipelines import resizecrop |
| 14 | +from skyreels_v2_infer.pipelines.image2video_pipeline import resizecrop |
| 15 | +from moviepy.editor import VideoFileClip |
| 16 | + |
| 17 | + |
| 18 | +def get_video_num_frames_moviepy(video_path): |
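|  | +    # Decode the clip once to learn its (width, height) and an exact frame count.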
| 19 | + with VideoFileClip(video_path) as clip: |
| 20 | + num_frames = 0 |
| 21 | + for _ in clip.iter_frames(): |
| 22 | + num_frames += 1 |
| 23 | + return clip.size, num_frames |
15 | 24 |
|
16 | | -if __name__ == "__main__": |
17 | 25 |
|
| 26 | +if __name__ == "__main__": |
18 | 27 | parser = argparse.ArgumentParser() |
19 | 28 | parser.add_argument("--outdir", type=str, default="diffusion_forcing") |
20 | 29 | parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P") |
21 | 30 | parser.add_argument("--resolution", type=str, choices=["540P", "720P"]) |
22 | 31 | parser.add_argument("--num_frames", type=int, default=97) |
23 | 32 | parser.add_argument("--image", type=str, default=None) |
| 33 | + parser.add_argument("--end_image", type=str, default=None) |
| 34 | + parser.add_argument("--video_path", type=str, default='') |
24 | 35 | parser.add_argument("--ar_step", type=int, default=0) |
25 | 36 | parser.add_argument("--causal_attention", action="store_true") |
26 | 37 | parser.add_argument("--causal_block_size", type=int, default=1) |
|
45 | 56 | "--teacache_thresh", |
46 | 57 | type=float, |
47 | 58 | default=0.2, |
48 | | - help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup", |
49 | | - ) |
| 59 | +        help="Higher speedup will cause worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup")
50 | 60 | parser.add_argument( |
51 | 61 | "--use_ret_steps", |
52 | 62 | action="store_true", |
53 | | - help="Using Retention Steps will result in faster generation speed and better generation quality.", |
54 | | - ) |
| 63 | + help="Using Retention Steps will result in faster generation speed and better generation quality.") |
55 | 64 | args = parser.parse_args() |
56 | 65 |
|
57 | 66 | args.model_id = download_model(args.model_id) |
|
85 | 94 |
|
86 | 95 | guidance_scale = args.guidance_scale |
87 | 96 | shift = args.shift |
88 | | - if args.image: |
89 | | - args.image = load_image(args.image) |
90 | | - image_width, image_height = args.image.size |
91 | | - if image_height > image_width: |
92 | | - height, width = width, height |
93 | | - args.image = resizecrop(args.image, height, width) |
94 | | - image = args.image.convert("RGB") if args.image else None |
| 97 | + |
95 | 98 | negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" |
96 | 99 |
|
97 | 100 | save_dir = os.path.join("result", args.outdir) |
98 | 101 | os.makedirs(save_dir, exist_ok=True) |
99 | 102 | local_rank = 0 |
100 | 103 | if args.use_usp: |
101 | | - assert ( |
102 | | - not args.prompt_enhancer |
103 | | - ), "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter." |
| 104 | + assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter." |
104 | 105 | from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment |
105 | 106 | import torch.distributed as dist |
106 | 107 |
|
|
138 | 139 |
|
139 | 140 | if args.causal_attention: |
140 | 141 | pipe.transformer.set_ar_attention(args.causal_block_size) |
141 | | - |
| 142 | + |
142 | 143 | if args.teacache: |
143 | 144 | if args.ar_step > 0: |
144 | | - num_steps = ( |
145 | | - args.inference_steps |
146 | | - + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step |
147 | | - ) |
148 | | - print("num_steps:", num_steps) |
| 145 | + num_steps = args.inference_steps + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step |
| 146 | + print('num_steps:', num_steps) |
149 | 147 | else: |
150 | 148 | num_steps = args.inference_steps |
151 | | - pipe.transformer.initialize_teacache( |
152 | | - enable_teacache=True, |
153 | | - num_steps=num_steps, |
154 | | - teacache_thresh=args.teacache_thresh, |
155 | | - use_ret_steps=args.use_ret_steps, |
156 | | - ckpt_dir=args.model_id, |
157 | | - ) |
| 149 | + pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=num_steps, |
| 150 | + teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps, |
| 151 | + ckpt_dir=args.model_id) |
158 | 152 |
|
159 | 153 | print(f"prompt:{prompt_input}") |
160 | 154 | print(f"guidance_scale:{guidance_scale}") |
161 | 155 |
|
162 | | - with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad(): |
163 | | - video_frames = pipe( |
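|  | +    # Video-extension mode: when --video_path points to an existing file, continue that clip instead of starting from a prompt/image.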
| 156 | + if os.path.exists(args.video_path): |
| 157 | + (v_width, v_height), input_num_frames = get_video_num_frames_moviepy(args.video_path) |
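|  | +        # The input clip must provide at least overlap_history frames to condition the continuation on.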
| 158 | + assert input_num_frames >= args.overlap_history, "The input video is too short." |
| 159 | + |
| 160 | + if v_height > v_width: |
| 161 | +            height, width = width, height  # portrait input: swap the target dimensions
| 162 | + |
| 163 | + video_frames = pipe.extend_video( |
164 | 164 | prompt=prompt_input, |
165 | 165 | negative_prompt=negative_prompt, |
166 | | - image=image, |
| 166 | + prefix_video_path=args.video_path, |
167 | 167 | height=height, |
168 | 168 | width=width, |
169 | 169 | num_frames=num_frames, |
|
178 | 178 | causal_block_size=args.causal_block_size, |
179 | 179 | fps=fps, |
180 | 180 | )[0] |
| 181 | + else: |
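|  | +        # Image/text-to-video mode: optionally condition on a first frame (--image) and a last frame (--end_image).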
| 182 | + if args.image: |
| 183 | + args.image = load_image(args.image) |
| 184 | + image_width, image_height = args.image.size |
| 185 | + if image_height > image_width: |
| 186 | + height, width = width, height |
| 187 | + args.image = resizecrop(args.image, height, width) |
| 188 | + if args.end_image: |
| 189 | + args.end_image = load_image(args.end_image) |
| 190 | + args.end_image = resizecrop(args.end_image, height, width) |
| 191 | + |
| 192 | + image = args.image.convert("RGB") if args.image else None |
| 193 | + end_image = args.end_image.convert("RGB") if args.end_image else None |
| 194 | + |
| 195 | + with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad(): |
| 196 | + video_frames = pipe( |
| 197 | + prompt=prompt_input, |
| 198 | + negative_prompt=negative_prompt, |
| 199 | + image=image, |
| 200 | + end_image=end_image, |
| 201 | + height=height, |
| 202 | + width=width, |
| 203 | + num_frames=num_frames, |
| 204 | + num_inference_steps=args.inference_steps, |
| 205 | + shift=shift, |
| 206 | + guidance_scale=guidance_scale, |
| 207 | + generator=torch.Generator(device="cuda").manual_seed(args.seed), |
| 208 | + overlap_history=args.overlap_history, |
| 209 | + addnoise_condition=args.addnoise_condition, |
| 210 | + base_num_frames=args.base_num_frames, |
| 211 | + ar_step=args.ar_step, |
| 212 | + causal_block_size=args.causal_block_size, |
| 213 | + fps=fps, |
| 214 | + )[0] |
181 | 215 |
|
182 | 216 | if local_rank == 0: |
183 | 217 | current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) |
|