Skip to content

Commit 702ca3c

Browse files
authored
Merge pull request #19 from StromNoNo/dev_acceleration
Dev acceleration
2 parents 76a3e71 + d8ff469 commit 702ca3c

File tree

7 files changed

+175
-39
lines changed

7 files changed

+175
-39
lines changed

README.md

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,41 @@ conda activate worldplay
119119
pip install -r requirements.txt
120120
```
121121

122-
### 2. Install Flash Attention (Optional but Recommended)
123-
Install Flash Attention for faster inference and reduced GPU memory consumption:
124-
```bash
125-
pip install flash-attn --no-build-isolation
126-
```
127-
Detailed instructions: [Flash Attention](https://github.com/Dao-AILab/flash-attention)
128-
129-
### 3. Download All Required Models
122+
### 2. Install Attention Libraries (Optional but Recommended)
123+
* Flash Attention:
124+
Install Flash Attention for faster inference and reduced GPU memory consumption:
125+
```bash
126+
pip install flash-attn --no-build-isolation
127+
```
128+
Detailed instructions: [Flash Attention](https://github.com/Dao-AILab/flash-attention)
129+
130+
131+
* SageAttention:
132+
To enable SageAttention for faster inference, install it with the following commands:
133+
```bash
134+
git clone https://github.com/cooper1637/SageAttention.git
135+
cd SageAttention
136+
export EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 # Optional
137+
python3 setup.py install
138+
```
139+
140+
### 3. Install AngelSlim and DeepGEMM
141+
* AngelSlim:
142+
Install AngelSlim to quantize the transformer.
143+
```bash
144+
pip install angelslim==0.2.2
145+
```
146+
147+
* DeepGEMM:
148+
To enable FP8 GEMM for the transformer, install DeepGEMM with the following commands:
149+
```bash
150+
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
151+
cd DeepGEMM
152+
./develop.sh
153+
./install.sh
154+
```
155+
156+
### 4. Download All Required Models
130157

131158
We provide a download script that automatically downloads all required models:
132159

@@ -315,7 +342,6 @@ https://github.com/user-attachments/assets/531bf0ad-1fca-4d76-bb65-84701368926d
315342
https://github.com/user-attachments/assets/f165f409-5a74-4e19-a32c-fc98d92259e1
316343

317344
## 📝 TODO
318-
- [ ] Acceleration & Quantization
319345
- [ ] Open-source training code
320346

321347
## 📚 Citation

generate.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def pose_to_input(pose_data, latent_num, tps=False):
264264
rotate_one_label = one_hot_to_one_dimension(rotate_one_hot)
265265
action_one_label = trans_one_label * 9 + rotate_one_label
266266

267-
return torch.tensor(w2c_list), torch.tensor(intrinsic_list), action_one_label
267+
return torch.as_tensor(w2c_list), torch.as_tensor(intrinsic_list), action_one_label
268268

269269
def save_video(video, path):
270270
if video.ndim == 5:
@@ -833,6 +833,38 @@ def main():
833833
'Use --with-ui or --with-ui true/1 to enable, --with-ui false/0 to disable'
834834
)
835835

836+
parser.add_argument(
837+
'--use_sageattn', type=str_to_bool, nargs='?', const=True, default=False,
838+
help='Enable sageattn (default: false). '
839+
'Use --use_sageattn or --use_sageattn true/1 to enable, '
840+
'--use_sageattn false/0 to disable'
841+
)
842+
parser.add_argument(
843+
'--sage_blocks_range', type=str, default="0-53",
844+
help='Sageattn blocks range (e.g., 0-5 or 0,1,2,3,4,5)'
845+
)
846+
parser.add_argument(
847+
'--use_vae_parallel', type=str_to_bool, nargs='?', const=True, default=False,
848+
help='Enable vae parallel (default: false). '
849+
'Use --use_vae_parallel or --use_vae_parallel true/1 to enable, '
850+
'--use_vae_parallel false/0 to disable'
851+
)
852+
# fp8 gemm related
853+
parser.add_argument(
854+
'--use_fp8_gemm', type=str_to_bool, nargs='?', const=True, default=False,
855+
help='Enable fp8 gemm for transformer (default: false). '
856+
'Use --use_fp8_gemm or --use_fp8_gemm true/1 to enable, '
857+
'--use_fp8_gemm false/0 to disable'
858+
)
859+
parser.add_argument(
860+
'--quant_type', type=str, default="fp8-per-block",
861+
help='Quantization type for fp8 gemm (e.g., fp8-per-tensor-weight-only, fp8-per-tensor, fp8-per-block)'
862+
)
863+
parser.add_argument(
864+
'--include_patterns', type=str, default="double_blocks",
865+
help='Include patterns for fp8 gemm (default: double_blocks)'
866+
)
867+
836868
args = parser.parse_args()
837869

838870
assert args.image_path is not None

hyvideo/commons/infer_state.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,22 @@
1515
# See the License for the specific language governing permissions and limitations under the License.
1616

1717
from typing import Optional
18-
from dataclasses import dataclass
18+
from dataclasses import dataclass, field
1919

2020
@dataclass
2121
class InferState:
2222
enable_sageattn: bool = False # whether to use SageAttention
2323
sage_blocks_range: Optional[range] = None # block range to use SageAttention
2424
enable_torch_compile: bool = False # whether to use torch compile
2525

26+
# fp8 gemm related
27+
use_fp8_gemm: bool = False # whether to use fp8 gemm
28+
quant_type: str = "fp8-per-block" # fp8 quantization type
29+
include_patterns: list = field(default_factory=lambda: ["double_blocks"]) # include patterns for fp8 gemm
30+
31+
# vae related
32+
use_vae_parallel: bool = False # whether to use vae parallel
33+
2634
__infer_state = None
2735

2836
def parse_range(value):
@@ -34,13 +42,28 @@ def parse_range(value):
3442

3543
def initialize_infer_state(args):
3644
global __infer_state
37-
sage_blocks_range = None
45+
sage_blocks_range = parse_range(args.sage_blocks_range)
3846
# Map CLI argument use_sageattn to internal enable_sageattn field
39-
use_sageattn = False
47+
use_sageattn = getattr(args, 'use_sageattn', False)
48+
49+
# Parse include_patterns from args
50+
include_patterns = getattr(args, 'include_patterns', "double_blocks")
51+
if isinstance(include_patterns, str):
52+
# Split by comma and strip whitespace
53+
include_patterns = [p.strip() for p in include_patterns.split(',') if p.strip()]
54+
4055
__infer_state = InferState(
4156
enable_sageattn = use_sageattn,
4257
sage_blocks_range = sage_blocks_range,
4358
enable_torch_compile = args.enable_torch_compile,
59+
60+
# fp8 gemm related
61+
use_fp8_gemm = args.use_fp8_gemm,
62+
quant_type = args.quant_type,
63+
include_patterns = include_patterns,
64+
65+
# vae related
66+
use_vae_parallel = args.use_vae_parallel,
4467
)
4568
return __infer_state
4669

hyvideo/models/autoencoders/hunyuanvideo_15_vae_w_cache.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,28 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
383383
# compute the shortcut part
384384
shortcut = rearrange(x, "b (r2 r3 c) f h w -> b c f (h r2) (w r3)", r2=2, r3=2)
385385
shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
386+
elif feat_cache is None and x.shape[2] > 1:
387+
# Multi-frame input without cache: first frame only spatial upsample, rest frames do spatio-temporal upsample
388+
# Separate first frame and remaining frames
389+
h_first = h[:, :, :1, :, :] # first frame
390+
h_rest = h[:, :, 1:, :, :] # remaining frames
391+
x_first = x[:, :, :1, :, :]
392+
x_rest = x[:, :, 1:, :, :]
393+
394+
# First frame: only spatial upsample
395+
h_first = rearrange(h_first, "b (r2 r3 c) f h w -> b c f (h r2) (w r3)", r2=2, r3=2)
396+
h_first = h_first[:, : h_first.shape[1] // 2]
397+
shortcut_first = rearrange(x_first, "b (r2 r3 c) f h w -> b c f (h r2) (w r3)", r2=2, r3=2)
398+
shortcut_first = shortcut_first.repeat_interleave(repeats=self.repeats // 2, dim=1)
399+
out_first = h_first + shortcut_first
400+
401+
# Remaining frames: spatio-temporal upsample
402+
h_rest = rearrange(h_rest, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
403+
shortcut_rest = rearrange(x_rest, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
404+
shortcut_rest = shortcut_rest.repeat_interleave(repeats=self.repeats, dim=1)
405+
out_rest = h_rest + shortcut_rest
406+
407+
return torch.cat([out_first, out_rest], dim=2)
386408
else:
387409
h = rearrange(h, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
388410
# compute the shortcut part
@@ -870,8 +892,9 @@ def tile_parallel_spatial_tiled_decode(self, z: torch.Tensor):
870892
decoded_metas.append(torch.tensor([ri, rj, pad_w, pad_h], device=z.device, dtype=torch.int64))
871893

872894
while len(decoded_tiles) < tiles_per_rank:
895+
T_out = decoded_tiles[0].shape[2] if len(decoded_tiles) > 0 else (T-1)*self.ffactor_temporal+1
873896
zero_tile = torch.zeros(
874-
[1, 3, (T - 1) * self.ffactor_temporal + 1, self.tile_sample_min_size, self.tile_sample_min_size],
897+
[1, 3, T_out, self.tile_sample_min_size, self.tile_sample_min_size],
875898
device=dec.device,
876899
dtype=dec.dtype
877900
)
@@ -891,6 +914,7 @@ def tile_parallel_spatial_tiled_decode(self, z: torch.Tensor):
891914

892915
dist.all_gather(tiles_gather_list, decoded_tiles, group=get_parallel_state().sp_group)
893916
dist.all_gather(metas_gather_list, decoded_metas, group=get_parallel_state().sp_group)
917+
dist.barrier()
894918

895919
if rank != 0:
896920
return torch.empty(0, device=z.device)

hyvideo/models/transformers/modules/attention.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,20 @@ def shrink_head(encoder_state, dim):
156156
t_kv['k_txt'] = encoder_key
157157
t_kv['v_txt'] = encoder_value
158158

159-
encoder_hidden_states = F.scaled_dot_product_attention(
160-
encoder_query,
161-
encoder_key,
162-
encoder_value,
163-
dropout_p=0.0,
164-
is_causal=False
165-
)
159+
infer_state = get_infer_state()
160+
enable_sageattn = (infer_state.enable_sageattn and
161+
block_idx in infer_state.sage_blocks_range)
162+
if enable_sageattn:
163+
from sageattention import sageattn
164+
encoder_hidden_states = sageattn(encoder_query, encoder_key, encoder_value, tensor_layout="HND", is_causal=False)
165+
else:
166+
encoder_hidden_states = F.scaled_dot_product_attention(
167+
encoder_query,
168+
encoder_key,
169+
encoder_value,
170+
dropout_p=0.0,
171+
is_causal=False
172+
)
166173

167174
# transpose back
168175
encoder_hidden_states = encoder_hidden_states.transpose(1, 2) # [B, S, H, D]
@@ -227,7 +234,14 @@ def sequence_parallel_attention_vision(q, k, v,
227234
key = torch.cat([encoder_key, key], dim=2)
228235
value = torch.cat([encoder_value, value], dim=2)
229236

230-
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
237+
infer_state = get_infer_state()
238+
enable_sageattn = (infer_state.enable_sageattn and
239+
block_idx in infer_state.sage_blocks_range)
240+
if enable_sageattn:
241+
from sageattention import sageattn
242+
hidden_states = sageattn(query, key, value, tensor_layout="HND", is_causal=False)
243+
else:
244+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
231245

232246
# transpose back
233247
hidden_states = hidden_states.transpose(1, 2) # [B, S, H, D]

hyvideo/pipelines/worldplay_video_pipeline.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
)
5353
from hyvideo.commons.parallel_states import get_parallel_state
5454

55+
from hyvideo.commons.infer_state import get_infer_state
56+
5557
from hyvideo.models.autoencoders import hunyuanvideo_15_vae_w_cache
5658
from hyvideo.models.text_encoders import PROMPT_TEMPLATE, TextEncoder
5759
from hyvideo.models.text_encoders.byT5 import load_glyph_byT5_v2
@@ -1635,6 +1637,10 @@ def __call__(
16351637
else:
16361638
latents = latents / self.vae.config.scaling_factor
16371639

1640+
if get_infer_state() and get_infer_state().use_vae_parallel:
1641+
self.vae.enable_spatial_tiling()
1642+
self.vae.enable_tile_parallelism()
1643+
16381644

16391645
if return_pre_sr_video or not enable_sr:
16401646
with (torch.autocast(device_type="cuda", dtype=self.vae_dtype, enabled=self.vae_autocast_enabled),
@@ -1767,6 +1773,14 @@ def create_pipeline(cls, pretrained_model_name_or_path, transformer_version, cre
17671773

17681774
transformer = transformer.to(transformer_dtype).to(transformer_init_device)
17691775

1776+
infer_state = get_infer_state()
1777+
if infer_state.use_fp8_gemm:
1778+
from angelslim.compressor.diffusion import DynamicDiTQuantizer
1779+
quant_type = infer_state.quant_type
1780+
include_patterns = infer_state.include_patterns
1781+
quantizer = DynamicDiTQuantizer(quant_type=quant_type, include_patterns=include_patterns)
1782+
quantizer.convert_linear(transformer)
1783+
17701784
vae = hunyuanvideo_15_vae_w_cache.AutoencoderKLConv3D.from_pretrained(
17711785
os.path.join(cached_folder, "vae"),
17721786
torch_dtype=vae_inference_config['dtype']

run.sh

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,22 @@ ENABLE_SR=false # Enable super resolution. When the NUM_FRAMES == 125, you can s
4646

4747
# inference with autoregressive model
4848
# torchrun --nproc_per_node=$N_INFERENCE_GPU generate.py \
49-
# --prompt "$PROMPT" \
50-
# --image_path $IMAGE_PATH \
51-
# --resolution $RESOLUTION \
52-
# --aspect_ratio $ASPECT_RATIO \
53-
# --video_length $NUM_FRAMES \
54-
# --seed $SEED \
55-
# --rewrite $REWRITE \
56-
# --sr $ENABLE_SR --save_pre_sr_video \
57-
# --pose "$POSE" \
58-
# --output_path $OUTPUT_PATH \
59-
# --model_path $MODEL_PATH \
60-
# --action_ckpt $AR_ACTION_MODEL_PATH \
61-
# --few_step false \
62-
# --width $WIDTH \
63-
# --height $HEIGHT \
64-
# --model_type 'ar'
49+
# --prompt "$PROMPT" \
50+
# --image_path $IMAGE_PATH \
51+
# --resolution $RESOLUTION \
52+
# --aspect_ratio $ASPECT_RATIO \
53+
# --video_length $NUM_FRAMES \
54+
# --seed $SEED \
55+
# --rewrite $REWRITE \
56+
# --sr $ENABLE_SR --save_pre_sr_video \
57+
# --pose "$POSE" \
58+
# --output_path $OUTPUT_PATH \
59+
# --model_path $MODEL_PATH \
60+
# --action_ckpt $AR_ACTION_MODEL_PATH \
61+
# --few_step false \
62+
# --width $WIDTH \
63+
# --height $HEIGHT \
64+
# --model_type 'ar'
6565

6666
# inference with autoregressive distilled model
6767
torchrun --nproc_per_node=$N_INFERENCE_GPU generate.py \
@@ -79,4 +79,7 @@ torchrun --nproc_per_node=$N_INFERENCE_GPU generate.py \
7979
--action_ckpt $AR_DISTILL_ACTION_MODEL_PATH \
8080
--few_step true \
8181
--num_inference_steps 4 \
82-
--model_type 'ar'
82+
--model_type 'ar' \
83+
--use_vae_parallel false \
84+
--use_sageattn false \
85+
--use_fp8_gemm false \

0 commit comments

Comments
 (0)