Skip to content

Commit 67be511

Browse files
CaitinZhaozhaoting
andauthored
[cogvideox] add more transformer config (mindspore-lab#925)
* [cogvideox] add more transformer config * add i2v transformer config support * add performance * update readme --------- Co-authored-by: zhaoting <zhaoting23@huawei.com>
1 parent 0fb111c commit 67be511

File tree

7 files changed

+174
-19
lines changed

7 files changed

+174
-19
lines changed

examples/diffusers/cogvideox_factory/README.md

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,38 @@ NODE_RANK="0"
356356
done
357357
```
358358

359-
要了解不同参数的含义,你可以查看 [args](./scripts/args.py) 文件,或者使用 `--help` 运行训练脚本。
359+
> [!TIP]
360+
> 如果想修改transformer的模型结构,可以设置`--transformer_config`。比如修改成30B的模型,可以设置`--transformer_config=configs/cogvideox1.5_30B.yaml`;
361+
> 当配置了`transformer_config`,可以配置`--transformer_ckpt_path`加载checkpoint权重。
362+
363+
要了解更多参数的含义,你可以查看 [args](./scripts/args.py) 文件,或者使用 `--help` 运行训练脚本。
364+
365+
## 性能数据
366+
367+
### 训练
368+
369+
| model | cards | DP | SP | zero | vae cache | video shape | precision | jit level | s/step | memory usage |
370+
|:-----------------:|:-----:|:--:|:--:|:-----:|:---------:|:-----------:|:---------:|:---------:|:------:|:------------:|
371+
| CogvideoX 1.5 T2V 5B | 8 | 8 | 1 | zero3 | ON | 1x77x768x1360 | bf16 | O1 | 39.23 | 35.6 GB |
372+
| CogvideoX 1.5 T2V 5B | 8 | 4 | 2 | zero3 | ON | 1x77x768x1360 | bf16 | O1 | 20.9 | 19.9 GB |
373+
| CogvideoX 1.5 T2V 5B | 8 | 2 | 4 | zero3 | ON | 1x77x768x1360 | bf16 | O1 | 10.1 | 14.6 GB |
374+
| CogvideoX 1.5 T2V 5B | 8 | 1 | 8 | zero3 | ON | 1x77x768x1360 | bf16 | O1 | 5.16 | 8.2 GB |
375+
| CogvideoX 1.5 T2V 5B | 16 | 2 | 8 | zero3 | ON | 1x77x768x1360 | bf16 | O1 | 5.24 | 6.3 GB |
376+
| CogvideoX 1.5 T2V 5B | 8 | 8 | 1 | zero3 | OFF | 1x77x768x1360 | bf16 | O1 | 49 | 40 GB |
377+
| CogvideoX 1.5 T2V 5B | 8 | 1 | 8 | zero3 | OFF | 1x77x768x1360 | bf16 | O1 | 10.58 | 9.3 GB |
378+
| CogvideoX 1.5 T2V 10B | 8 | 2 | 4 | zero3 | ON | 1x77x768x1360 | bf16 | O1 | 15.2 | 25.6 GB |
379+
| CogvideoX 1.5 T2V 20B | 8 | 2 | 4 | zero3 | ON | 1x77x768x1360 | bf16 | O1 | 20.1 | 35.7 GB |
380+
| CogvideoX 1.5 T2V 30B | 8 | 2 | 4 | zero3 | ON | 1x77x768x1360 | bf16 | O1 | 26.5 | 47.3 GB |
381+
382+
以上数据在Disney数据集,910*上获得。
383+
384+
### 推理
385+
386+
| model | cards | DP | SP | zero | video shape | precision | jit level | s/step | total cost |
387+
|:-----------------:|:-----:|:--:|:--:|:-----:|:-------------:|:---------:|:---------:|:------:|:----------:|
388+
| CogvideoX 1.5 T2V 5B | 8 | 1 | 8 | zero3 | 1x77x768x1360 | bf16 | O1 | 3.21 | ~ 5min |
360389

390+
以上数据在910*上获得。
361391

362392
## 与原仓的差异&功能限制
363393

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
transformer:
2+
"activation_fn": "gelu-approximate"
3+
"attention_bias": True
4+
"attention_head_dim": 96
5+
"dropout": 0.0
6+
"flip_sin_to_cos": True
7+
"freq_shift": 0
8+
"in_channels": 16
9+
"max_text_seq_length": 226
10+
"norm_elementwise_affine": True
11+
"norm_eps": 1e-05
12+
"num_attention_heads": 48
13+
"num_layers": 48
14+
"out_channels": 16
15+
"patch_bias": False
16+
"patch_size": 2
17+
"patch_size_t": 2
18+
"sample_frames": 81
19+
"sample_height": 96
20+
"sample_width": 170
21+
"spatial_interpolation_scale": 1.875
22+
"temporal_compression_ratio": 4
23+
"temporal_interpolation_scale": 1.0
24+
"text_embed_dim": 4096
25+
"time_embed_dim": 512
26+
"timestep_activation_fn": "silu"
27+
"use_learned_positional_embeddings": False
28+
"use_rotary_positional_embeddings": True
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
transformer:
2+
"activation_fn": "gelu-approximate"
3+
"attention_bias": True
4+
"attention_head_dim": 128
5+
"dropout": 0.0
6+
"flip_sin_to_cos": True
7+
"freq_shift": 0
8+
"in_channels": 16
9+
"max_text_seq_length": 226
10+
"norm_elementwise_affine": True
11+
"norm_eps": 1e-05
12+
"num_attention_heads": 48
13+
"num_layers": 48
14+
"out_channels": 16
15+
"patch_bias": False
16+
"patch_size": 2
17+
"patch_size_t": 2
18+
"sample_frames": 81
19+
"sample_height": 96
20+
"sample_width": 170
21+
"spatial_interpolation_scale": 1.875
22+
"temporal_compression_ratio": 4
23+
"temporal_interpolation_scale": 1.0
24+
"text_embed_dim": 4096
25+
"time_embed_dim": 512
26+
"timestep_activation_fn": "silu"
27+
"use_learned_positional_embeddings": False
28+
"use_rotary_positional_embeddings": True
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
transformer:
2+
"activation_fn": "gelu-approximate"
3+
"attention_bias": True
4+
"attention_head_dim": 128
5+
"dropout": 0.0
6+
"flip_sin_to_cos": True
7+
"freq_shift": 0
8+
"in_channels": 16
9+
"max_text_seq_length": 226
10+
"norm_elementwise_affine": True
11+
"norm_eps": 1e-05
12+
"num_attention_heads": 48
13+
"num_layers": 64
14+
"out_channels": 16
15+
"patch_bias": False
16+
"patch_size": 2
17+
"patch_size_t": 2
18+
"sample_frames": 81
19+
"sample_height": 96
20+
"sample_width": 170
21+
"spatial_interpolation_scale": 1.875
22+
"temporal_compression_ratio": 4
23+
"temporal_interpolation_scale": 1.0
24+
"text_embed_dim": 4096
25+
"time_embed_dim": 512
26+
"timestep_activation_fn": "silu"
27+
"use_learned_positional_embeddings": False
28+
"use_rotary_positional_embeddings": True

examples/diffusers/cogvideox_factory/scripts/args.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@ def _get_model_args(parser: argparse.ArgumentParser) -> None:
1717
required=True,
1818
help="Path to pretrained model or model identifier from huggingface.co/models.",
1919
)
20+
parser.add_argument(
21+
"--transformer_config",
22+
type=str,
23+
default=None,
24+
help="Config of the transformer. If set, the transformer config from pretrained_model_name_or_path is not used.",
25+
)
26+
parser.add_argument(
27+
"--transformer_ckpt_path",
28+
type=str,
29+
default=None,
30+
help="Path to the transformer checkpoint. Only effective when transformer_config is set.",
31+
)
2032
parser.add_argument(
2133
"--revision",
2234
type=str,

examples/diffusers/cogvideox_factory/scripts/cogvideox_image_to_video_sft.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,29 @@ def main(args):
171171
# CogVideoX-2b weights are stored in float16
172172
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
173173
# load_dtype = ms.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else ms.float16
174-
transformer = CogVideoXTransformer3DModel_SP.from_pretrained(
175-
args.pretrained_model_name_or_path,
176-
subfolder="transformer",
177-
mindspore_dtype=weight_dtype,
178-
revision=args.revision,
179-
variant=args.variant,
180-
max_text_seq_length=args.max_sequence_length,
181-
enable_sequence_parallelism=enable_sequence_parallelism,
182-
)
174+
if args.transformer_config is None:
175+
transformer = CogVideoXTransformer3DModel_SP.from_pretrained(
176+
args.pretrained_model_name_or_path,
177+
subfolder="transformer",
178+
mindspore_dtype=weight_dtype,
179+
revision=args.revision,
180+
variant=args.variant,
181+
max_text_seq_length=args.max_sequence_length,
182+
enable_sequence_parallelism=enable_sequence_parallelism,
183+
)
184+
elif os.path.exists(args.transformer_config):
185+
with open(args.transformer_config) as f:
186+
config = yaml.safe_load(f)["transformer"]
187+
config["max_text_seq_length"] = args.max_sequence_length
188+
config["enable_sequence_parallelism"] = enable_sequence_parallelism
189+
transformer = CogVideoXTransformer3DModel_SP(**config)
190+
logger.info(f"Build transformer model from {args.transformer_config}")
191+
if os.path.exists(args.transformer_ckpt_path):
192+
ms.load_checkpoint(args.transformer_ckpt_path, transformer)
193+
logger.info(f"Load transformer checkpoint from {args.transformer_ckpt_path}")
194+
195+
else:
196+
raise ValueError(f"transformer_config: {args.transformer_config} does not exist!")
183197
transformer.fa_checkpointing = args.fa_gradient_checkpointing
184198

185199
text_encoder, vae = None, None

examples/diffusers/cogvideox_factory/scripts/cogvideox_text_to_video_sft.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,30 @@ def main(args):
168168
# CogVideoX-2b weights are stored in float16
169169
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
170170
# load_dtype = ms.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else ms.float16
171-
transformer = CogVideoXTransformer3DModel_SP.from_pretrained(
172-
args.pretrained_model_name_or_path,
173-
subfolder="transformer",
174-
mindspore_dtype=weight_dtype,
175-
revision=args.revision,
176-
variant=args.variant,
177-
max_text_seq_length=args.max_sequence_length,
178-
enable_sequence_parallelism=enable_sequence_parallelism,
179-
)
171+
if args.transformer_config is None:
172+
transformer = CogVideoXTransformer3DModel_SP.from_pretrained(
173+
args.pretrained_model_name_or_path,
174+
subfolder="transformer",
175+
mindspore_dtype=weight_dtype,
176+
revision=args.revision,
177+
variant=args.variant,
178+
max_text_seq_length=args.max_sequence_length,
179+
enable_sequence_parallelism=enable_sequence_parallelism,
180+
)
181+
elif os.path.exists(args.transformer_config):
182+
with open(args.transformer_config) as f:
183+
config = yaml.safe_load(f)["transformer"]
184+
config["max_text_seq_length"] = args.max_sequence_length
185+
config["enable_sequence_parallelism"] = enable_sequence_parallelism
186+
transformer = CogVideoXTransformer3DModel_SP(**config)
187+
logger.info(f"Build transformer model from {args.transformer_config}")
188+
if os.path.exists(args.transformer_ckpt_path):
189+
ms.load_checkpoint(args.transformer_ckpt_path, transformer)
190+
logger.info(f"Load transformer checkpoint from {args.transformer_ckpt_path}")
191+
192+
else:
193+
raise ValueError(f"transformer_config: {args.transformer_config} does not exist!")
194+
180195
transformer.fa_checkpointing = args.fa_gradient_checkpointing
181196

182197
text_encoder, vae = None, None

0 commit comments

Comments
 (0)