
Commit abed76f

Merge pull request #77 from SkyworkAI/dev
Dev branch: support video extension and start/end frame control
2 parents 12c9c20 + be4ac62, commit abed76f

4 files changed (+348, -33 lines)

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -16,4 +16,5 @@ scripts/.gradio/*
 # *.csv
 *.jsonl
 out/*
-model/
+model/
+run.sh

README.md

Lines changed: 48 additions & 0 deletions
@@ -13,6 +13,7 @@ Welcome to the **SkyReels V2** repository! Here, you'll find the model weights a
 
 
 ## 🔥🔥🔥 News!!
+* May 16, 2025: 🔥 We release the inference code for [video extension](#ve) and [start/end frame control](#se) in the diffusion forcing model.
 * Apr 24, 2025: 🔥 We release the 720P models, [SkyReels-V2-DF-14B-720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P) and [SkyReels-V2-I2V-14B-720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P). The former facilitates infinite-length autoregressive video generation, and the latter focuses on Image2Video synthesis.
 * Apr 21, 2025: 👋 We release the inference code and model weights of [SkyReels-V2](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) Series Models and the video captioning model [SkyCaptioner-V1](https://huggingface.co/Skywork/SkyCaptioner-V1).
 * Apr 3, 2025: 🔥 We also release [SkyReels-A2](https://github.com/SkyworkAI/SkyReels-A2). This is an open-sourced controllable video generation framework capable of assembling arbitrary visual elements.
@@ -222,6 +223,51 @@ python3 generate_video_df.py \
 > - `--addnoise_condition` is used to help smooth long video generation by adding some noise to the clean condition. Too much noise can cause inconsistency as well. 20 is a recommended value; you may try larger ones, but it is recommended not to exceed 50.
 > - Generating a 540P video using the 1.3B model requires approximately 14.7GB peak VRAM, while the same resolution video using the 14B model demands around 51.2GB peak VRAM.
 
+- **<span id="ve">Video Extension</span>**
+```shell
+model_id=Skywork/SkyReels-V2-DF-14B-540P
+# video extension
+python3 generate_video_df.py \
+  --model_id ${model_id} \
+  --resolution 540P \
+  --ar_step 0 \
+  --base_num_frames 97 \
+  --num_frames 120 \
+  --overlap_history 17 \
+  --prompt ${prompt} \
+  --addnoise_condition 20 \
+  --offload \
+  --use_ret_steps \
+  --teacache \
+  --teacache_thresh 0.3 \
+  --video_path ${video_path}
+```
+> **Note**:
+> - When performing video extension, you need to pass the `--video_path ${video_path}` parameter to specify the video to be extended.
+
+- **<span id="se">Start/End Frame Control</span>**
+```shell
+model_id=Skywork/SkyReels-V2-DF-14B-540P
+# start/end frame control
+python3 generate_video_df.py \
+  --model_id ${model_id} \
+  --resolution 540P \
+  --ar_step 0 \
+  --base_num_frames 97 \
+  --num_frames 97 \
+  --overlap_history 17 \
+  --prompt ${prompt} \
+  --addnoise_condition 20 \
+  --offload \
+  --use_ret_steps \
+  --teacache \
+  --teacache_thresh 0.3 \
+  --image ${image} \
+  --end_image ${end_image}
+```
+> **Note**:
+> - When controlling the start and end frames, you need to pass the `--image ${image}` parameter to control the generation of the start frame and the `--end_image ${end_image}` parameter to control the generation of the end frame.
+
 - **Text To Video & Image To Video**
 
 ```shell
@@ -288,6 +334,8 @@ Below are the key parameters you can customize for video generation:
 | --overlap_history | 17 | Number of frames to overlap for smooth transitions in long videos |
 | --addnoise_condition | 20 | Improves consistency in long video generation |
 | --causal_block_size | 5 | Recommended when using asynchronous inference (--ar_step > 0) |
+| --video_path | | Path to input video for video extension |
+| --end_image | | Path to input image for end frame control |
 
 #### Multi-GPU inference using xDiT USP

generate_video_df.py

Lines changed: 66 additions & 32 deletions
@@ -11,16 +11,27 @@
 from skyreels_v2_infer import DiffusionForcingPipeline
 from skyreels_v2_infer.modules import download_model
 from skyreels_v2_infer.pipelines import PromptEnhancer
-from skyreels_v2_infer.pipelines import resizecrop
+from skyreels_v2_infer.pipelines.image2video_pipeline import resizecrop
+from moviepy.editor import VideoFileClip
+
+
+def get_video_num_frames_moviepy(video_path):
+    with VideoFileClip(video_path) as clip:
+        num_frames = 0
+        for _ in clip.iter_frames():
+            num_frames += 1
+        return clip.size, num_frames
 
-if __name__ == "__main__":
 
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--outdir", type=str, default="diffusion_forcing")
     parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P")
     parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
     parser.add_argument("--num_frames", type=int, default=97)
     parser.add_argument("--image", type=str, default=None)
+    parser.add_argument("--end_image", type=str, default=None)
+    parser.add_argument("--video_path", type=str, default='')
     parser.add_argument("--ar_step", type=int, default=0)
     parser.add_argument("--causal_attention", action="store_true")
     parser.add_argument("--causal_block_size", type=int, default=1)
@@ -45,13 +56,11 @@
         "--teacache_thresh",
         type=float,
         default=0.2,
-        help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup",
-    )
+        help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup")
     parser.add_argument(
         "--use_ret_steps",
         action="store_true",
-        help="Using Retention Steps will result in faster generation speed and better generation quality.",
-    )
+        help="Using Retention Steps will result in faster generation speed and better generation quality.")
     args = parser.parse_args()
 
     args.model_id = download_model(args.model_id)
@@ -85,22 +94,14 @@
 
     guidance_scale = args.guidance_scale
     shift = args.shift
-    if args.image:
-        args.image = load_image(args.image)
-        image_width, image_height = args.image.size
-        if image_height > image_width:
-            height, width = width, height
-        args.image = resizecrop(args.image, height, width)
-    image = args.image.convert("RGB") if args.image else None
+
     negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
 
     save_dir = os.path.join("result", args.outdir)
     os.makedirs(save_dir, exist_ok=True)
     local_rank = 0
     if args.use_usp:
-        assert (
-            not args.prompt_enhancer
-        ), "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
+        assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
         from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
         import torch.distributed as dist
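
The image preprocessing removed here reappears further down, inside the new start/end-frame branch. Its orientation handling is easy to check in isolation; a minimal sketch (the 544x960 target is only an illustrative stand-in for what `--resolution` would produce, and the blank PIL image stands in for a real input):

```python
from PIL import Image

height, width = 544, 960             # illustrative target only; the script sets these from --resolution
img = Image.new("RGB", (720, 1280))  # portrait stand-in: width < height

image_width, image_height = img.size
if image_height > image_width:
    # Same swap as in the script: a portrait input flips the target to portrait as well.
    height, width = width, height

print(height, width)                 # 960 544
```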

@@ -138,32 +139,31 @@
 
     if args.causal_attention:
         pipe.transformer.set_ar_attention(args.causal_block_size)
-
+
     if args.teacache:
         if args.ar_step > 0:
-            num_steps = (
-                args.inference_steps
-                + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
-            )
-            print("num_steps:", num_steps)
+            num_steps = args.inference_steps + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
+            print('num_steps:', num_steps)
         else:
             num_steps = args.inference_steps
-        pipe.transformer.initialize_teacache(
-            enable_teacache=True,
-            num_steps=num_steps,
-            teacache_thresh=args.teacache_thresh,
-            use_ret_steps=args.use_ret_steps,
-            ckpt_dir=args.model_id,
-        )
+        pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=num_steps,
+                                             teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps,
+                                             ckpt_dir=args.model_id)
 
     print(f"prompt:{prompt_input}")
     print(f"guidance_scale:{guidance_scale}")
 
-    with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
-        video_frames = pipe(
+    if os.path.exists(args.video_path):
+        (v_width, v_height), input_num_frames = get_video_num_frames_moviepy(args.video_path)
+        assert input_num_frames >= args.overlap_history, "The input video is too short."
+
+        if v_height > v_width:
+            height, width = width, height
+
+        video_frames = pipe.extend_video(
             prompt=prompt_input,
             negative_prompt=negative_prompt,
-            image=image,
+            prefix_video_path=args.video_path,
             height=height,
             width=width,
             num_frames=num_frames,
@@ -178,6 +178,40 @@
             causal_block_size=args.causal_block_size,
             fps=fps,
         )[0]
+    else:
+        if args.image:
+            args.image = load_image(args.image)
+            image_width, image_height = args.image.size
+            if image_height > image_width:
+                height, width = width, height
+            args.image = resizecrop(args.image, height, width)
+        if args.end_image:
+            args.end_image = load_image(args.end_image)
+            args.end_image = resizecrop(args.end_image, height, width)
+
+        image = args.image.convert("RGB") if args.image else None
+        end_image = args.end_image.convert("RGB") if args.end_image else None
+
+        with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
+            video_frames = pipe(
+                prompt=prompt_input,
+                negative_prompt=negative_prompt,
+                image=image,
+                end_image=end_image,
+                height=height,
+                width=width,
+                num_frames=num_frames,
+                num_inference_steps=args.inference_steps,
+                shift=shift,
+                guidance_scale=guidance_scale,
+                generator=torch.Generator(device="cuda").manual_seed(args.seed),
+                overlap_history=args.overlap_history,
+                addnoise_condition=args.addnoise_condition,
+                base_num_frames=args.base_num_frames,
+                ar_step=args.ar_step,
+                causal_block_size=args.causal_block_size,
+                fps=fps,
+            )[0]
 
     if local_rank == 0:
         current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
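
The teacache step count collapsed into a one-liner in the hunk above is easier to follow with concrete numbers. A quick sanity check with illustrative settings (30 inference steps and ar_step 5 are example values, not ones fixed by this commit; 97 is the script's base_num_frames default and 5 is the causal_block_size recommended for asynchronous inference):

```python
inference_steps = 30       # example value
base_num_frames = 97       # script default
causal_block_size = 5      # recommended when --ar_step > 0
ar_step = 5                # example asynchronous step count

latent_frames = (base_num_frames - 1) // 4 + 1        # (97 - 1) // 4 + 1 = 25
blocks = latent_frames // causal_block_size           # 25 // 5 = 5
num_steps = inference_steps + (blocks - 1) * ar_step  # 30 + 4 * 5 = 50
print(num_steps)                                      # 50
```

With the default synchronous setting (`--ar_step 0`), the else branch applies and `num_steps` stays equal to `--inference_steps`.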
