
Commit 0275b50

Merge pull request #50 from Howe2018/main
fix df image size adaptation && Multi-GPU teacache support
2 parents: 3a164c1 + 0d4a9a9

File tree

2 files changed (+110 −16 lines):
  generate_video_df.py
  skyreels_v2_infer/distributed/xdit_context_parallel.py

generate_video_df.py

Lines changed: 28 additions & 10 deletions
@@ -11,6 +11,7 @@
 from skyreels_v2_infer import DiffusionForcingPipeline
 from skyreels_v2_infer.modules import download_model
 from skyreels_v2_infer.pipelines import PromptEnhancer
+from skyreels_v2_infer.pipelines import resizecrop

 if __name__ == "__main__":

@@ -44,11 +45,13 @@
         "--teacache_thresh",
         type=float,
         default=0.2,
-        help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup")
+        help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup",
+    )
     parser.add_argument(
         "--use_ret_steps",
         action="store_true",
-        help="Using Retention Steps will result in faster generation speed and better generation quality.")
+        help="Using Retention Steps will result in faster generation speed and better generation quality.",
+    )
     args = parser.parse_args()

     args.model_id = download_model(args.model_id)
@@ -82,14 +85,22 @@

     guidance_scale = args.guidance_scale
     shift = args.shift
-    image = load_image(args.image).convert("RGB") if args.image else None
+    if args.image:
+        args.image = load_image(args.image)
+        image_width, image_height = args.image.size
+        if image_height > image_width:
+            height, width = width, height
+        args.image = resizecrop(args.image, height, width)
+    image = args.image.convert("RGB") if args.image else None
     negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

     save_dir = os.path.join("result", args.outdir)
     os.makedirs(save_dir, exist_ok=True)
     local_rank = 0
     if args.use_usp:
-        assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
+        assert (
+            not args.prompt_enhancer
+        ), "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
         from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
         import torch.distributed as dist

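Note on the image-size fix above: when the input image is portrait (height > width), the target height and width are swapped before resizing, so portrait inputs keep their orientation instead of being squeezed into a landscape frame. A minimal stand-alone sketch of what a resize-and-center-crop helper like resizecrop is assumed to do (the actual implementation lives in skyreels_v2_infer/pipelines; resizecrop_sketch here is hypothetical):

from PIL import Image

def resizecrop_sketch(img: Image.Image, target_h: int, target_w: int) -> Image.Image:
    # Assumed behavior: scale so the image fully covers the target box,
    # then center-crop the overflow to hit (target_w, target_h) exactly.
    scale = max(target_w / img.width, target_h / img.height)
    img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
    left = (img.width - target_w) // 2
    top = (img.height - target_h) // 2
    return img.crop((left, top, left + target_w, top + target_h))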

@@ -127,16 +138,23 @@

     if args.causal_attention:
         pipe.transformer.set_ar_attention(args.causal_block_size)
-
+
     if args.teacache:
         if args.ar_step > 0:
-            num_steps = args.inference_steps + (((args.base_num_frames - 1)//4 + 1) // args.causal_block_size - 1) * args.ar_step
-            print('num_steps:', num_steps)
+            num_steps = (
+                args.inference_steps
+                + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
+            )
+            print("num_steps:", num_steps)
         else:
             num_steps = args.inference_steps
-        pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=num_steps,
-                                             teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps,
-                                             ckpt_dir=args.model_id)
+        pipe.transformer.initialize_teacache(
+            enable_teacache=True,
+            num_steps=num_steps,
+            teacache_thresh=args.teacache_thresh,
+            use_ret_steps=args.use_ret_steps,
+            ckpt_dir=args.model_id,
+        )

     print(f"prompt:{prompt_input}")
     print(f"guidance_scale:{guidance_scale}")

skyreels_v2_infer/distributed/xdit_context_parallel.py

Lines changed: 82 additions & 6 deletions
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 import torch.amp as amp
 from torch.backends.cuda import sdp_kernel
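numpy comes in for np.poly1d, which the TeaCache logic below uses to map the raw relative-L1 distance between successive timestep embeddings to a calibrated estimate of output change. A toy illustration (these coefficients are made up for the example; the real ones are model-specific):

import numpy as np

coefficients = [2.0, 0.5, 0.01]   # hypothetical: 2x^2 + 0.5x + 0.01
rescale_func = np.poly1d(coefficients)

raw_rel_l1 = 0.12                 # mean|e0_t - e0_prev| / mean|e0_prev|
print(rescale_func(raw_rel_l1))   # 2*0.12**2 + 0.5*0.12 + 0.01 = 0.0988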
@@ -59,6 +60,17 @@ def rope_apply(x, grid_sizes, freqs):
     return torch.stack(output).float()


+def broadcast_should_calc(should_calc: bool) -> bool:
+    import torch.distributed as dist
+
+    device = torch.cuda.current_device()
+    int_should_calc = 1 if should_calc else 0
+    tensor = torch.tensor([int_should_calc], device=device, dtype=torch.int8)
+    dist.broadcast(tensor, src=0)
+    should_calc = tensor.item() == 1
+    return should_calc
+
+
 def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
     """
     x: A list of videos each with shape [C, T, H, W].
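broadcast_should_calc is the multi-GPU half of the fix: each sequence-parallel rank evaluates the skip heuristic on its own shard, and floating-point differences could make ranks disagree. If one rank skipped the transformer blocks while another entered them, the collectives inside the blocks would deadlock, so rank 0's decision is broadcast and every rank takes the same branch. A toy sketch of the pattern (run_blocks and cached_residual are hypothetical stand-ins):

# All ranks now hold rank 0's boolean, so the branch below is uniform.
should_calc = broadcast_should_calc(should_calc)
if should_calc:
    x = run_blocks(x)         # contains collectives; all ranks must enter together
else:
    x = x + cached_residual   # purely local; safe for all ranks to skip in lockstep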
@@ -135,20 +147,84 @@ def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
     e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()]
     kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)

-    # Context Parallel
-    x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+    if self.enable_teacache:
+        modulated_inp = e0 if self.use_ref_steps else e
+        # teacache
+        if self.cnt % 2 == 0:  # even -> conditon
+            self.is_even = True
+            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                should_calc_even = True
+                self.accumulated_rel_l1_distance_even = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_even += rescale_func(
+                    ((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean())
+                    .cpu()
+                    .item()
+                )
+                if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
+                    should_calc_even = False
+                else:
+                    should_calc_even = True
+                    self.accumulated_rel_l1_distance_even = 0
+            self.previous_e0_even = modulated_inp.clone()
+        else:  # odd -> unconditon
+            self.is_even = False
+            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                should_calc_odd = True
+                self.accumulated_rel_l1_distance_odd = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_odd += rescale_func(
+                    ((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean())
+                    .cpu()
+                    .item()
+                )
+                if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
+                    should_calc_odd = False
+                else:
+                    should_calc_odd = True
+                    self.accumulated_rel_l1_distance_odd = 0
+            self.previous_e0_odd = modulated_inp.clone()

-    for block in self.blocks:
-        x = block(x, **kwargs)
+    x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+    if self.enable_teacache:
+        if self.is_even:
+            should_calc_even = broadcast_should_calc(should_calc_even)
+            if not should_calc_even:
+                x += self.previous_residual_even
+            else:
+                ori_x = x.clone()
+                for block in self.blocks:
+                    x = block(x, **kwargs)
+                ori_x.mul_(-1)
+                ori_x.add_(x)
+                self.previous_residual_even = ori_x
+        else:
+            should_calc_odd = broadcast_should_calc(should_calc_odd)
+            if not should_calc_odd:
+                x += self.previous_residual_odd
+            else:
+                ori_x = x.clone()
+                for block in self.blocks:
+                    x = block(x, **kwargs)
+                ori_x.mul_(-1)
+                ori_x.add_(x)
+                self.previous_residual_odd = ori_x
+        self.cnt += 1
+        if self.cnt >= self.num_steps:
+            self.cnt = 0
+    else:
+        # Context Parallel
+        for block in self.blocks:
+            x = block(x, **kwargs)

     # head
     if e.ndim == 3:
         e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
     x = self.head(x, e)
-
     # Context Parallel
     x = get_sp_group().all_gather(x, dim=1)
-
     # unpatchify
     x = self.unpatchify(x, grid_sizes)
     return x.float()
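A note on the in-place pair ori_x.mul_(-1); ori_x.add_(x) in the hunk above: it turns the pre-block snapshot into the residual x_after − x_before without allocating a third tensor, and on a later skipped step that cached residual is added back as a cheap stand-in for rerunning the blocks. A toy demonstration (shapes and the block stand-in are illustrative):

import torch

x = torch.randn(2, 8, 16)
ori_x = x.clone()                  # snapshot before the transformer blocks
x = x * 1.7 + 0.3                  # stand-in for: for block in self.blocks: x = block(x, **kwargs)
ori_x.mul_(-1)
ori_x.add_(x)                      # ori_x now holds (x_after - x_before), computed in place
previous_residual = ori_x

# On a skipped step, bypass the blocks and reuse the cached residual:
x_skipped = x + previous_residual  # approximates the blocks' effect on a nearby input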
