
Commit 00f2d1a

Merge pull request #1169 from Feng0w0/sample_add
Docs: Supplement NPU training script samples and documentation instructions

2 parents 8cc3bec + 62c3d40

File tree

13 files changed (+286, -9 lines)


diffsynth/core/device/__init__.py

2 additions & 1 deletion

```diff
@@ -1 +1,2 @@
-from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
+from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
+from .npu_compatible_device import IS_NPU_AVAILABLE
```
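For context, a minimal sketch of the contract these re-exported helpers appear to satisfy; `npu_compatible_device.py` itself is not shown in this diff, so the detection logic below (an optional `torch_npu` import that registers the `torch.npu` namespace) is an assumption, not the project's actual source:

```python
# Hypothetical sketch of the npu_compatible_device helpers (not the real module body).
import torch

try:
    import torch_npu  # Ascend plugin; registers the torch.npu namespace on import
    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    IS_NPU_AVAILABLE = False


def get_device_name() -> str:
    """Return a concrete device string such as "npu:0" for the current process."""
    if IS_NPU_AVAILABLE:
        return f"npu:{torch.npu.current_device()}"
    return "cuda:0" if torch.cuda.is_available() else "cpu"
```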

diffsynth/core/vram/layers.py

2 additions & 2 deletions

```diff
@@ -2,7 +2,7 @@
 from typing import Union
 from .initialization import skip_model_initialization
 from .disk_map import DiskMap
-from ..device import parse_device_type
+from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
 
 
 class AutoTorchModule(torch.nn.Module):
@@ -63,7 +63,7 @@ def cast_to(self, weight, dtype, device):
         return r
 
     def check_free_vram(self):
-        device = self.computation_device if self.computation_device != "npu" else "npu:0"
+        device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
         gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
         used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
         return used_memory < self.vram_limit
```
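The patched `check_free_vram` relies on `mem_get_info` returning a `(free_bytes, total_bytes)` tuple, as `torch.cuda.mem_get_info` does. A standalone sketch of the same computation, assuming `torch_npu` mirrors the CUDA API:

```python
import torch


def used_vram_gb(device_type: str, device: str) -> float:
    # mem_get_info returns (free_bytes, total_bytes), so used = total - free.
    free, total = getattr(torch, device_type).mem_get_info(device)
    return (total - free) / (1024 ** 3)


# used_vram_gb("cuda", "cuda:0") on NVIDIA; used_vram_gb("npu", "npu:0") works
# analogously once torch_npu is installed.
```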

diffsynth/diffusion/base_pipeline.py

2 additions & 1 deletion

```diff
@@ -7,6 +7,7 @@
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
+from ..core.device import get_device_name, IS_NPU_AVAILABLE
 
 
 class PipelineUnit:
@@ -177,7 +178,7 @@ def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=t
 
 
     def get_vram(self):
-        device = self.device if self.device != "npu" else "npu:0"
+        device = self.device if not IS_NPU_AVAILABLE else get_device_name()
         return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
 
     def get_module(self, model, name):
```
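The reason the hard-coded `"npu:0"` had to go is multi-card launches: each rank should query its own card, not card 0. A sketch of the failure mode this fixes, assuming one process per card and that `torch_npu` mirrors the `torch.cuda` API:

```python
import os

import torch
import torch_npu  # assumed installed; registers the torch.npu namespace

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.npu.set_device(local_rank)  # bind this process to its own card

# Querying "npu:0" here would report card 0's memory on every rank;
# querying the current device reports each rank's own card.
free, total = torch.npu.mem_get_info(f"npu:{torch.npu.current_device()}")
print(f"rank {local_rank}: {free / 1024**3:.1f} GiB free of {total / 1024**3:.1f} GiB")
```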

docs/en/Pipeline_Usage/GPU_support.md

28 additions & 2 deletions

````diff
@@ -13,7 +13,7 @@ All sample code provided by this project supports NVIDIA GPUs by default, requir
 AMD provides PyTorch packages based on ROCm, so most models can run without code changes. A small number of models may not be compatible due to their reliance on CUDA-specific instructions.
 
 ## Ascend NPU
-
+### Inference
 When using Ascend NPU, you need to replace `"cuda"` with `"npu"` in your code.
 
 For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for Ascend NPU:
@@ -22,6 +22,7 @@ For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for As
 import torch
 from diffsynth.utils.data import save_video, VideoData
 from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
+from diffsynth.core.device.npu_compatible_device import get_device_name
 
 vram_config = {
     "offload_dtype": "disk",
@@ -46,7 +47,7 @@ pipe = WanVideoPipeline.from_pretrained(
     ],
     tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
 - vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
-+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
++ vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,
 )
 
 video = pipe(
@@ -56,3 +57,28 @@ video = pipe(
 )
 save_video(video, "video.mp4", fps=15, quality=5)
 ```
+
+### Training
+NPU startup script samples have been added for each model family. The scripts are stored under `examples/xxx/special/npu_scripts`, for example `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`.
+
+The NPU training scripts set NPU-specific environment variables that improve performance, and enable extra parameters required by specific models.
+
+#### Environment variables
+```shell
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+```
+`expandable_segments:<value>`: enables the memory pool's expandable-segments (virtual memory) feature, which helps reduce fragmentation.
+
+```shell
+export CPU_AFFINITY_CONF=1
+```
+`CPU_AFFINITY_CONF` controls CPU core binding:
+
+- 0 or unset: core binding disabled
+- 1: coarse-grained core binding enabled
+- 2: fine-grained core binding enabled
+
+#### Parameters for specific models
+| Model | Parameter | Note |
+|----------------|-----------------------------|---------------------------------------------------|
+| Wan 14B series | `--initialize_model_on_cpu` | The 14B model needs to be initialized on the CPU |
````
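If exporting these variables in a shell script is inconvenient, they can also be set from Python, assuming (as with their CUDA counterparts) that they are read when `torch`/`torch_npu` initialize; a sketch:

```python
import os

# Must run before torch / torch_npu are imported, or the allocator and
# affinity settings will not pick the values up.
os.environ.setdefault("PYTORCH_NPU_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("CPU_AFFINITY_CONF", "1")

import torch  # noqa: E402  (deliberately imported after the environment is set)
```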

docs/zh/Pipeline_Usage/GPU_support.md

29 additions & 3 deletions

````diff
@@ -13,7 +13,7 @@
 AMD provides ROCm-based torch packages, so most models run without code changes; a few models depend on CUDA-specific instructions and cannot run.
 
 ## Ascend NPU
-
+### Inference
 When using an Ascend NPU, replace `"cuda"` with `"npu"` in your code.
 
 For example, the inference code for Wan2.1-T2V-1.3B:
@@ -22,6 +22,7 @@ AMD provides ROCm-based torch packages, so most models run without code changes
 import torch
 from diffsynth.utils.data import save_video, VideoData
 from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
+from diffsynth.core.device.npu_compatible_device import get_device_name
 
 vram_config = {
     "offload_dtype": "disk",
@@ -33,7 +34,7 @@ vram_config = {
 + "preparing_device": "npu",
   "computation_dtype": torch.bfloat16,
 - "computation_device": "cuda",
-+ "preparing_device": "npu",
++ "computation_device": "npu",
 }
 pipe = WanVideoPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
@@ -46,7 +47,7 @@ pipe = WanVideoPipeline.from_pretrained(
     ],
     tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
 - vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
-+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
++ vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,
 )
 
 video = pipe(
@@ -56,3 +57,28 @@ video = pipe(
 )
 save_video(video, "video.mp4", fps=15, quality=5)
 ```
+
+### Training
+NPU launch script samples have been added for each model family. The scripts are stored under `examples/xxx/special/npu_scripts`, for example `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`.
+
+The NPU training scripts set NPU-specific environment variables that improve performance, and enable extra parameters for specific models.
+
+#### Environment variables
+```shell
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+```
+`expandable_segments:<value>`: enables the memory pool's expandable-segments (virtual memory) feature.
+
+```shell
+export CPU_AFFINITY_CONF=1
+```
+`CPU_AFFINITY_CONF` controls CPU core binding:
+
+- 0 or unset: core binding disabled
+- 1: coarse-grained core binding enabled
+- 2: fine-grained core binding enabled
+
+#### Parameters required by specific models
+| Model | Parameter | Note |
+|----------------|-----------------------------|------------------------------------------------|
+| Wan 14B series | `--initialize_model_on_cpu` | The 14B models must be initialized on the CPU |
````
Lines changed: 17 additions & 0 deletions (new file)

```shell
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
  --data_file_keys "image,kontext_images" \
  --max_pixels 1048576 \
  --dataset_repeat 400 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
  --learning_rate 1e-5 \
  --num_epochs 1 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/FLUX.1-Kontext-dev_full" \
  --trainable_models "dit" \
  --extra_inputs "kontext_images" \
  --use_gradient_checkpointing
```
Lines changed: 15 additions & 0 deletions (new file)

```shell
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 400 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
  --learning_rate 1e-5 \
  --num_epochs 1 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/FLUX.1-dev_full" \
  --trainable_models "dit" \
  --use_gradient_checkpointing
```
Lines changed: 38 additions & 0 deletions (new file)

```shell
# Due to memory limitations, split training is required to train the model on NPU
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 1 \
  --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:text_encoder/model*.safetensors,Qwen/Qwen-Image-Edit-2509:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-Edit-2509-LoRA-splited-cache" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --task "sft:data_process"

accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path "./models/train/Qwen-Image-Edit-2509-LoRA-splited-cache" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-Edit-2509-LoRA-splited" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --task "sft:train"
```
Lines changed: 38 additions & 0 deletions (new file)

```shell
# Due to memory limitations, split training is required to train the model on NPU
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 1 \
  --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-LoRA-splited-cache" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --task "sft:data_process"

accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path "./models/train/Qwen-Image-LoRA-splited-cache" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-LoRA-splited" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --task "sft:train"
```
Lines changed: 16 additions & 0 deletions (new file)

```shell
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 480 \
  --width 832 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.1-T2V-14B_full" \
  --trainable_models "dit" \
  --initialize_model_on_cpu
```
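For illustration, `--initialize_model_on_cpu` presumably makes the trainer materialize the 14B weights in host RAM before moving them to the NPU, so construction does not run out of device memory. A hedged sketch of that pattern (`build_model` is a hypothetical constructor, not a DiffSynth API):

```python
import torch


def load_large_model(build_model, device: str = "npu"):
    # Allocate parameters in host RAM first (PyTorch >= 2.0 default-device
    # context), then move them to the accelerator once construction succeeds.
    with torch.device("cpu"):
        model = build_model()
    return model.to(device)  # requires torch_npu when device is an NPU
```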
