diff --git a/README.md b/README.md index 5c4be40f0..e370f83d6 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,8 @@ docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/llmcompression:pure-lates ## Latest News +- **May 12, 2025:** 🔥 We now fully support quantization for the **`Wan2.1`** series of video generation models and provide export of truly quantized **INT8/FP8** weights, compatible with the [lightx2v](https://github.com/ModelTC/lightx2v) inference framework. For details, please refer to the [lightx2v documentation](https://llmc-en.readthedocs.io/en/latest/backend/lightx2v.html). + - **Feb 7, 2025:** 🔥 We now fully support quantization of large-scale **`MOE`** models like **`DeepSeekv3`**, **`DeepSeek-R1`**, and **`DeepSeek-R1-zero`** with **`671B`** parameters. You can now directly load FP8 weights without any extra conversion. AWQ and RTN quantization can run on a single 80GB GPU, and we also support the export of true quantized **INT4/INT8** weights. - **Nov 20, 2024:** 🔥 We now fully support the quantization of ✨`DeepSeekv2(2.5)` and other `MOE` models, as well as ✨`Qwen2VL`, `Llama3.2`, and other `VLM` models. Supported quantization methods include ✅integer quantization, ✅floating-point quantization, and advanced algorithms like ✅AWQ, ✅GPTQ, ✅SmoothQuant, and ✅Quarot. diff --git a/README_ja.md b/README_ja.md index 85df12326..c716ec945 100644 --- a/README_ja.md +++ b/README_ja.md @@ -48,7 +48,9 @@ docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/llmcompression:pure-lates ## 最新情報 -- V 🔥 私たちは現在、671Bパラメータを持つ大規模な **`MOE`** モデル、例えば **`DeepSeekv3`**、**`DeepSeek-R1`**、および **`DeepSeek-R1-zero`** の量子化を完全にサポートしています。今すぐFP8ウェイトを追加の変換なしで直接読み込むことができます。AWQおよびRTN量子化は、1枚の80GB GPUで実行でき、さらに、真の量子化された **INT4/INT8** ウェイトのエクスポートにも対応しています。 +- **2025年5月12日:** 🔥 **`Wan2.1`** シリーズのビデオ生成モデルの量子化を完全にサポートし、実際に量子化された **INT8/FP8** 重みのエクスポートにも対応しました。これらは [lightx2v](https://github.com/ModelTC/lightx2v) 推論フレームワークと互換性があります。詳細は [lightx2v ドキュメント](https://llmc-en.readthedocs.io/en/latest/backend/lightx2v.html) をご参照ください。 + +- **2025年2月7日:** 🔥 私たちは現在、671Bパラメータを持つ大規模な **`MOE`** モデル、例えば **`DeepSeekv3`**、**`DeepSeek-R1`**、および **`DeepSeek-R1-zero`** の量子化を完全にサポートしています。今すぐFP8ウェイトを追加の変換なしで直接読み込むことができます。AWQおよびRTN量子化は、1枚の80GB GPUで実行でき、さらに、真の量子化された **INT4/INT8** ウェイトのエクスポートにも対応しています。 - **2024年11月20日:** 🔥 私たちは現在、✨`DeepSeekv2(2.5)`などの`MOE`モデルおよび✨`Qwen2VL`、`Llama3.2`などの`VLM`モデルの量子化を完全にサポートしています。対応する量子化手法には、✅整数量子化、✅浮動小数点量子化、さらに✅AWQ、✅GPTQ、✅SmoothQuant、✅Quarotといった高度なアルゴリズムが含まれます。 diff --git a/README_zh.md b/README_zh.md index 9adcae0a2..af4682769 100644 --- a/README_zh.md +++ b/README_zh.md @@ -48,6 +48,8 @@ docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/llmcompression:pure-lates ## 最新消息 +- **2025年5月12日:** 🔥 我们现已全面支持 **`Wan2.1`** 系列视频生成模型的量化,并支持导出真实量化的 **INT8/FP8** 权重,兼容 [lightx2v](https://github.com/ModelTC/lightx2v) 推理框架。详情请参考 [lightx2v 使用文档](https://llmc-zhcn.readthedocs.io/en/latest/backend/lightx2v.html)。 + - **2025年2月7日:** 🔥 我们现已全面支持 **`DeepSeekv3`**、**`DeepSeek-R1`** 和 **`DeepSeek-R1-zero`** 等 671B 大规模 **`MOE`** 模型的量化。 您可以直接加载 `FP8` 权重,无需额外转换,使用单张 80G 显存的 GPU 即可运行 `AWQ` 和 `RTN` 量化,同时还支持导出真实量化的 **INT4/INT8** 权重 - **2024年11月20日:** 🔥 我们现已全面支持✨`DeepSeekv2(2.5)`等`MOE`模型以及✨`Qwen2VL`、`Llama3.2`等`VLM`模型的量化。支持的量化方案包括✅整型量化、✅浮点量化,以及✅AWQ、✅GPTQ、✅SmoothQuant 和 ✅Quarot 等先进算法。 diff --git a/assets/wan_i2v/calib/astronaut.jpg b/assets/wan_i2v/calib/astronaut.jpg new file mode 100644 index 000000000..b2c8d3aa4 Binary files /dev/null and b/assets/wan_i2v/calib/astronaut.jpg differ diff --git a/assets/wan_i2v/calib/samples.json 
b/assets/wan_i2v/calib/samples.json new file mode 100755 index 000000000..b810dfd76 --- /dev/null +++ b/assets/wan_i2v/calib/samples.json @@ -0,0 +1,7 @@ +[ + { + "image": "astronaut.jpg", + "prompt": "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot.", + "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + } +] diff --git a/assets/wan_i2v/eval/astronaut.jpg b/assets/wan_i2v/eval/astronaut.jpg new file mode 100644 index 000000000..b2c8d3aa4 Binary files /dev/null and b/assets/wan_i2v/eval/astronaut.jpg differ diff --git a/assets/wan_i2v/eval/samples.json b/assets/wan_i2v/eval/samples.json new file mode 100755 index 000000000..b810dfd76 --- /dev/null +++ b/assets/wan_i2v/eval/samples.json @@ -0,0 +1,7 @@ +[ + { + "image": "astronaut.jpg", + "prompt": "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot.", + "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + } +] diff --git a/assets/wan_t2v/calib/samples.json b/assets/wan_t2v/calib/samples.json old mode 100644 new mode 100755 diff --git a/assets/wan_t2v/eval/samples.json b/assets/wan_t2v/eval/samples.json old mode 100644 new mode 100755 diff --git a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml new file mode 100755 index 000000000..680fab43b --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml @@ -0,0 +1,49 @@ +base: + seed: &seed 42 +model: + type: WanI2V + path: /path/to/model + torch_dtype: auto +calib: + name: i2v + download: False + path: ../assets/wan_i2v/calib/ + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: ../assets/wan_i2v/eval/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_awq/ +quant: + video_gen: + method: Awq + weight: + bit: 8 + symmetric: True + granularity: per_channel + group_size: -1 + act: + bit: 8 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: False + clip_sym: True +save: + save_lightx2v: True + save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_i2v/rtn_w_a.yaml b/configs/quantization/video_gen/wan_i2v/rtn_w_a.yaml new file mode 100755 index 000000000..15ff69a45 --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/rtn_w_a.yaml @@ -0,0 +1,32 @@ +base: + seed: &seed 42 +model: + type: WanI2V + path: /path/to/model + 
torch_dtype: auto +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: ../assets/wan_i2v/eval/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_rtn/ +quant: + video_gen: + method: RTN + weight: + bit: 8 + symmetric: True + granularity: per_channel + act: + bit: 8 + symmetric: True + granularity: per_token +save: + save_lightx2v: True + save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_i2v/rtn_w_a_lora.yaml b/configs/quantization/video_gen/wan_i2v/rtn_w_a_lora.yaml new file mode 100755 index 000000000..39925db7e --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/rtn_w_a_lora.yaml @@ -0,0 +1,33 @@ +base: + seed: &seed 42 +model: + type: WanI2V + path: /path/to/model + lora_path: /path/to/lora_weights + torch_dtype: auto +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: ../assets/wan_i2v/eval/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_rtn_lora/ +quant: + video_gen: + method: RTN + weight: + bit: 8 + symmetric: True + granularity: per_channel + act: + bit: 8 + symmetric: True + granularity: per_token +save: + save_lightx2v: True + save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml new file mode 100755 index 000000000..e68dea80e --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml @@ -0,0 +1,45 @@ +base: + seed: &seed 42 +model: + type: WanI2V + path: /path/to/model + torch_dtype: auto +calib: + name: i2v + download: False + path: ../assets/wan_i2v/calib/ + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: ../assets/wan_i2v/eval/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_sq/ +quant: + video_gen: + method: SmoothQuant + weight: + bit: 8 + symmetric: True + granularity: per_channel + act: + bit: 8 + symmetric: True + granularity: per_token + special: + alpha: 0.75 +save: + save_lightx2v: True + save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml new file mode 100755 index 000000000..e1cdc989a --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml @@ -0,0 +1,49 @@ +base: + seed: &seed 42 +model: + type: WanI2V + path: /path/to/model + torch_dtype: auto +calib: + name: i2v + download: False + path: ../assets/wan_i2v/calib/ + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: ../assets/wan_i2v/eval/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_sq/ +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 +save: + save_lightx2v: True + save_path: /path/to/x2v/ diff --git 
a/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_int8_lora.yaml b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_int8_lora.yaml new file mode 100755 index 000000000..6df416f77 --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_int8_lora.yaml @@ -0,0 +1,46 @@ +base: + seed: &seed 42 +model: + type: WanI2V + path: /path/to/model + lora_path: /path/to/lora_weights + torch_dtype: auto +calib: + name: i2v + download: False + path: ../assets/wan_i2v/calib/ + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: ../assets/wan_i2v/eval/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_sq/ +quant: + video_gen: + method: SmoothQuant + weight: + bit: 8 + symmetric: True + granularity: per_channel + act: + bit: 8 + symmetric: True + granularity: per_token + special: + alpha: 0.75 +save: + save_lightx2v: True + save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml old mode 100644 new mode 100755 index 142297432..14d05479d --- a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml @@ -5,7 +5,7 @@ model: path: /path/to/wan_t2v torch_dtype: auto calib: - name: custom_t2v + name: t2v download: False path: ../assets/wan_t2v/calib/ sample_steps: 20 @@ -18,7 +18,7 @@ calib: eval: eval_pos: [transformed, fake_quant] type: video_gen - name: custom_t2v + name: t2v download: False path: ../assets/wan_t2v/calib/ bs: 1 @@ -45,6 +45,5 @@ quant: weight_clip: True clip_sym: True save: - save_trans: False - save_fake: False - save_path: /path/to/save/ + save_lightx2v: True + save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml old mode 100644 new mode 100755 index 85ef32edd..b6a53b0e0 --- a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml @@ -7,15 +7,15 @@ model: eval: eval_pos: [transformed, fake_quant] type: video_gen - name: custom_t2v + name: t2v download: False - path: /mtc/gushiqiao/llmc_video_new/llmc/assets/wan_t2v/ + path: ../assets/wan_t2v/eval/ bs: 1 target_height: 480 target_width: 832 num_frames: 81 guidance_scale: 5.0 - output_video_path: ./output_videos_sq/ + output_video_path: ./output_videos_rtn/ quant: video_gen: method: RTN @@ -28,6 +28,5 @@ quant: symmetric: True granularity: per_token save: - save_trans: False - save_fake: False - save_path: /path/to/save/ + save_lightx2v: True + save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml old mode 100644 new mode 100755 index 5a60faa3e..7d65f31fc --- a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml @@ -5,7 +5,7 @@ model: path: /path/to/wan_t2v torch_dtype: auto calib: - name: custom_t2v + name: t2v download: False path: ../assets/wan_t2v/calib/ sample_steps: 20 @@ -18,7 +18,7 @@ calib: eval: eval_pos: [transformed, fake_quant] type: video_gen - name: custom_t2v + name: t2v download: False path: ../assets/wan_t2v/calib/ bs: 1 @@ -41,6 +41,5 @@ quant: special: alpha: 0.7 save: - save_trans: False - save_fake: False - 
save_path: /path/to/save/ + save_lightx2v: True + save_path: /path/to/x2v/ diff --git a/docs/en/source/backend/lightx2v.md b/docs/en/source/backend/lightx2v.md new file mode 100755 index 000000000..e046f9846 --- /dev/null +++ b/docs/en/source/backend/lightx2v.md @@ -0,0 +1,177 @@ +# lightx2v Quantized Inference + +[lightx2v](https://github.com/ModelTC/lightx2v) is an efficient backend designed specifically to meet the inference demands of video generation models. By optimizing memory management and computational efficiency, it significantly accelerates the inference process. + +**LLMC** supports exporting quantized model formats required by **lightx2v** and offers strong support for multiple quantization algorithms (such as AWQ, GPTQ, SmoothQuant, etc.), maintaining high quantization accuracy while improving inference speed. Combining **LLMC** with **lightx2v** enables accelerated inference and memory optimization without compromising accuracy, making it ideal for scenarios that require efficient video model processing. + +--- + +## 1.1 Environment Setup + +To use **lightx2v** for quantized inference, first install and configure the environment: + +```bash +# Clone the repository and its submodules +git clone https://github.com/ModelTC/lightx2v.git lightx2v && cd lightx2v +git submodule update --init --recursive + +# Create and activate the conda environment +conda create -n lightx2v python=3.11 && conda activate lightx2v +pip install -r requirements.txt + +# Reinstall transformers separately to bypass version conflicts +pip install transformers==4.45.2 + +# Install flash-attention 2 +cd lightx2v/3rd/flash-attention && pip install --no-cache-dir -v -e . + +# Install flash-attention 3 (only if using Hopper architecture) +cd lightx2v/3rd/flash-attention/hopper && pip install --no-cache-dir -v -e . +``` + +--- + +## 1.2 Quantization Formats + +**lightx2v** supports several fixed-point quantization formats: + +- **W8A8**: int8 for weights and activations. +- **FP8 (E4M3)**: float8 for weights and activations. +- **Weight per-channel quantization**. +- **Activation per-token dynamic quantization** for improved precision. +- **Symmetric quantization** for both weights and activations (uses only scale). + +When using **LLMC** to quantize models, ensure the bit-width of weights and activations matches supported **lightx2v** formats. + +--- + +## 1.3 Quantizing Models with LLMC + +### 1.3.1 Calibration Data + +For example, for the Wan2.1 model on the I2V task, a calibration dataset is provided in the [directory](https://github.com/ModelTC/llmc/tree/main/assets/wan_i2v/calib). Users can add more samples as needed. + +### 1.3.2 Choosing Quantization Algorithm + +#### **W8A8** + +We recommend using **SmoothQuant** for W8A8 settings. +Refer to the SmoothQuant W8A8 [configuration file](https://github.com/ModelTC/llmc/tree/main/configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml): + +```yaml +quant: + video_gen: + method: SmoothQuant + weight: + bit: 8 + symmetric: True + granularity: per_channel + act: + bit: 8 + symmetric: True + granularity: per_token + special: + alpha: 0.75 +``` + +If SmoothQuant does not meet the precision requirement, use **AWQ** for better accuracy. See the corresponding [configuration](https://github.com/ModelTC/llmc/tree/main/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml). + +#### **FP8-Dynamic** + +LLMC supports FP8 quantization with per-channel weights and per-token dynamic activations. SmoothQuant is again recommended. 
See the SmoothQuant FP8 [configuration](https://github.com/ModelTC/llmc/tree/main/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml): +```yaml +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 +``` + +Ensure `quant_type` is set to `float-quant` and `use_qtorch` to `True`, as **LLMC** uses [QPyTorch](https://github.com/Tiiiger/QPyTorch) for float quantization. + +Install QPyTorch with: + +```bash +pip install qtorch +``` + +### 1.3.3 Exporting the Quantized Model + +```yaml +save: + save_lightx2v: True + save_path: /path/to/save_for_lightx2v/ +``` + +Set `save_lightx2v` to `True`. LLMC will export weights as `torch.int8` or `torch.float8_e4m3fn` for direct loading in **lightx2v**, along with quantization parameters. + +### 1.3.4 Running LLMC + +Edit the config path in the run script and execute: + +```bash +# scripts/run_llmc.sh +llmc=llmc_path +export PYTHONPATH=$llmc:$PYTHONPATH + +task_name=sq_for_lightx2v +config=${llmc}/configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml +``` + +After LLMC completes, the quantized model is saved to `save.save_path`. + +### 1.3.5 Evaluation + +For the I2V task with the Wan2.1 model, an evaluation dataset is provided [here](https://github.com/ModelTC/llmc/tree/main/assets/wan_i2v/eval). Set the following in the config file: + +```yaml +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: ../assets/wan_i2v/eval/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_sq/ +``` + +LLMC will generate evaluation videos using the pseudo-quantized model. + +--- + +## 1.4 Inference with lightx2v + +### 1.4.1 Weight Structure Conversion + +After LLMC exports the model, convert its structure to match **lightx2v** requirements using the [conversion script](https://github.com/ModelTC/lightx2v/blob/main/examples/diffusers/converter.py): + +```bash +python converter.py -s /path/to/save_for_lightx2v/ -o /path/to/output/ -d backward +``` + +The converted model will be saved under `/path/to/output/`.
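+
+As a quick sanity check (a minimal sketch rather than part of the official workflow — the safetensors file name below is an assumption, so point it at whichever file your export or conversion actually produced), you can list the tensor dtypes to confirm that the weights were written as `torch.int8` or `torch.float8_e4m3fn` together with their scales:
+
+```python
+# Minimal sketch: inspect an exported/converted checkpoint.
+# The file name is hypothetical; adjust it to your own output.
+from safetensors import safe_open
+
+path = "/path/to/output/diffusion_pytorch_model.safetensors"
+with safe_open(path, framework="pt") as f:
+    for name in f.keys():
+        tensor = f.get_tensor(name)
+        # Quantized weights should show up as torch.int8 / torch.float8_e4m3fn,
+        # with per-channel scales stored as separate floating-point tensors.
+        print(name, tuple(tensor.shape), tensor.dtype)
+```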
+ +### 1.4.2 Offline Inference + +Edit the [inference script](https://github.com/ModelTC/lightx2v/blob/main/scripts/run_wan_i2v_advanced_ptq.sh), set `model_path` to `/path/to/output/` and `lightx2v_path` to your local lightx2v path, then run: + +```bash +bash run_wan_i2v_advanced_ptq.sh +``` diff --git a/docs/zh_cn/source/backend/lightx2v.md b/docs/zh_cn/source/backend/lightx2v.md new file mode 100755 index 000000000..9aa3be442 --- /dev/null +++ b/docs/zh_cn/source/backend/lightx2v.md @@ -0,0 +1,177 @@ + +# lightx2v 量化推理 + +[lightx2v](https://github.com/ModelTC/lightx2v) 是一个专为满足视频生成模型推理需求设计的高效后端。它通过优化内存管理和计算效率,能够显著加速推理过程。 + +**LLMC** 支持导出 lightx2v 所需的量化模型格式,并通过其对多种量化算法的强大支持(如 AWQ、GPTQ、SmoothQuant 等),能够在保证推理速度的同时保持较高的量化精度。将 **LLMC** 与 **lightx2v** 结合使用,可以在不牺牲精度的前提下实现推理加速和内存优化,非常适合需要高效处理视频生成模型的应用场景。 + +--- + +## 1.1 环境准备 + +要使用 **lightx2v** 进行量化推理,首先需要安装并配置相关环境: + +```bash +# 克隆仓库及其子模块 +git clone https://github.com/ModelTC/lightx2v.git lightx2v && cd lightx2v +git submodule update --init --recursive + +# 创建并激活 conda 环境 +conda create -n lightx2v python=3.11 && conda activate lightx2v +pip install -r requirements.txt + +# 为避免版本冲突,单独安装 transformers +pip install transformers==4.45.2 + +# 安装 flash-attention 2 +cd lightx2v/3rd/flash-attention && pip install --no-cache-dir -v -e . + +# 安装 flash-attention 3(仅在 Hopper 架构下) +cd lightx2v/3rd/flash-attention/hopper && pip install --no-cache-dir -v -e . +``` + +--- + +## 1.2 量化格式 + +**lightx2v** 支持以下几种常见的定点量化格式: + +- **W8A8**:权重和激活均为 int8; +- **FP8 (E4M3)**:权重和激活均为 float8; +- **权重 per-channel 量化**; +- **激活 per-token 动态量化**,进一步提升精度; +- **对称量化**(仅使用 scale 参数)。 + +使用 **LLMC** 进行模型量化时,必须确保权重和激活的比特数符合 **lightx2v** 所支持的格式。 + +--- + +## 1.3 使用 LLMC 进行模型量化 + +### 1.3.1 校准数据 + +以 Wan2.1 模型在 I2V 任务为例,校准数据示例可在[此目录](https://github.com/ModelTC/llmc/tree/main/assets/wan_i2v/calib)中找到,用户可根据需求添加更多数据。 + +### 1.3.2 量化算法选择 + +#### **W8A8** + +推荐使用 **SmoothQuant** 算法,配置参考如下 [配置文件](https://github.com/ModelTC/llmc/tree/main/configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml): + +```yaml +quant: + video_gen: + method: SmoothQuant + weight: + bit: 8 + symmetric: True + granularity: per_channel + act: + bit: 8 + symmetric: True + granularity: per_token + special: + alpha: 0.75 +``` + +如果 SmoothQuant 无法满足精度需求,可以尝试使用 **AWQ**,相关配置请参考 [AWQ 配置文件](https://github.com/ModelTC/llmc/tree/main/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml)。 + +#### **FP8 动态量化** + +对于 FP8 格式,LLMC 支持权重 per-channel、激活 per-token 动态量化。推荐仍使用 **SmoothQuant**,参考配置如下: + +```yaml +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 +``` + +请确保将 `quant_type` 设置为 `float-quant`,并将 `use_qtorch` 设置为 `True`,因为 LLMC 的浮点量化依赖于 [QPyTorch](https://github.com/Tiiiger/QPyTorch)。 + +安装 QPyTorch: + +```bash +pip install qtorch +``` + +### 1.3.3 导出真实量化模型 + +```yaml +save: + save_lightx2v: True + save_path: /path/to/save_for_lightx2v/ +``` + +务必将 `save_lightx2v` 设置为 `True`。LLMC 会将权重以 `torch.int8` 或 `torch.float8_e4m3fn` 形式导出,供 lightx2v 直接使用,并附带相应的量化参数。 + +### 1.3.4 运行 LLMC + +编辑运行脚本中的配置路径: + +```bash +# scripts/run_llmc.sh +llmc=llmc_path +export PYTHONPATH=$llmc:$PYTHONPATH + +task_name=sq_for_lightx2v +config=${llmc}/configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml +``` + +运行完成后,真实量化模型会保存在 `save.save_path` 中。 + +### 1.3.5 模型评估 + +以 Wan2.1 在 I2V 
任务为例,测试数据在[此目录](https://github.com/ModelTC/llmc/tree/main/assets/wan_i2v/eval),配置参考如下: + +```yaml +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: ../assets/wan_i2v/eval/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_sq/ +``` + +LLMC 会生成使用伪量化模型生成的视频结果。 + +--- + +## 1.4 使用 lightx2v 进行模型推理 + +### 1.4.1 权重结构转换 + +LLMC 导出后,需将模型结构转换为 lightx2v 支持的格式,可使用 [转换脚本](https://github.com/ModelTC/lightx2v/blob/main/examples/diffusers/converter.py): + +```bash +python converter.py -s /path/to/save_for_lightx2v/ -o /path/to/output/ -d backward +``` + +转换后的模型将保存在 `/path/to/output/`。 + +### 1.4.2 离线推理 + +编辑 [推理脚本](https://github.com/ModelTC/lightx2v/blob/main/scripts/run_wan_i2v_advanced_ptq.sh),设置 `model_path` 为 `/path/to/output/`,`lightx2v_path` 为本地路径,然后运行: + +```bash +bash run_wan_i2v_advanced_ptq.sh +``` diff --git a/llmc/__main__.py b/llmc/__main__.py index 33194b67f..98db4ca26 100755 --- a/llmc/__main__.py +++ b/llmc/__main__.py @@ -121,39 +121,43 @@ def main(config): if config.save.get('save_vllm', False): deploy_all_modality(blockwise_opts, 'vllm_quant') - if config.save.get('save_lightllm', False): + elif config.save.get('save_lightllm', False): deploy_all_modality(blockwise_opts, 'lightllm_quant') - if config.save.get('save_sgl', False): + elif config.save.get('save_sgl', False): deploy_all_modality(blockwise_opts, 'sgl_quant') blockwise_opt.save_model(save_quant_path) update_vllm_quant_config(blockwise_opt.model, config, save_quant_path) - if 'save' in config and config.save.get('save_autoawq', False): - for modality_config in modality_configs: - assert ( - modality_config.weight.bit in [4] and 'act' not in modality_config - ), 'AutoAWQ supports only 4-bit weight-only quantization.' - assert ( - not modality_config.weight.symmetric - ), 'Only asymmetric quant is supported.' - - deploy_all_modality(blockwise_opts, 'autoawq_quant') - blockwise_opt.save_model(save_quant_path) - update_autoawq_quant_config(config, save_quant_path) - - if 'save' in config and config.save.get('save_mlcllm', False): - for modality_config in modality_configs: - assert ( - modality_config.weight.bit in [4] and 'act' not in modality_config - ), 'MlcLLM supports only 4-bit weight-only quantization.' - assert ( - not modality_config.weight.symmetric - ), 'Only asymmetric quant is supported.' - - deploy_all_modality(blockwise_opts, 'mlcllm_quant') - blockwise_opt.save_model(save_quant_path) - update_autoawq_quant_config(config, save_quant_path) + elif config.save.get('save_autoawq', False): + for modality_config in modality_configs: + assert ( + modality_config.weight.bit in [4] and 'act' not in modality_config + ), 'AutoAWQ supports only 4-bit weight-only quantization.' + assert ( + not modality_config.weight.symmetric + ), 'Only asymmetric quant is supported.' + + deploy_all_modality(blockwise_opts, 'autoawq_quant') + blockwise_opt.save_model(save_quant_path) + update_autoawq_quant_config(config, save_quant_path) + + elif config.save.get('save_mlcllm', False): + for modality_config in modality_configs: + assert ( + modality_config.weight.bit in [4] and 'act' not in modality_config + ), 'MlcLLM supports only 4-bit weight-only quantization.' + assert ( + not modality_config.weight.symmetric + ), 'Only asymmetric quant is supported.' 
+ + deploy_all_modality(blockwise_opts, 'mlcllm_quant') + blockwise_opt.save_model(save_quant_path) + update_autoawq_quant_config(config, save_quant_path) + + elif config.save.get('save_lightx2v', False): + deploy_all_modality(blockwise_opts, 'lightx2v_quant') + blockwise_opt.save_model(save_quant_path) if 'opencompass' in config: assert config.save.get('save_trans', False) @@ -240,6 +244,11 @@ def main(config): config.save.save_path, 'mlcllm_quant_model' ) mkdirs(save_quant_path) + if config.save.get('save_lightx2v', False): + save_quant_path = os.path.join( + config.save.save_path, 'lightx2v_quant_model' + ) + mkdirs(save_quant_path) if config.save.get('save_fake', False): save_fake_path = os.path.join(config.save.save_path, 'fake_quant_model') mkdirs(save_fake_path) diff --git a/llmc/compression/quantization/auto_clip.py b/llmc/compression/quantization/auto_clip.py index 047d117b8..62425745e 100755 --- a/llmc/compression/quantization/auto_clip.py +++ b/llmc/compression/quantization/auto_clip.py @@ -10,10 +10,13 @@ if is_fp8_supported_gpu(): from .kernel import weight_cast_to_bf16, weight_cast_to_fp8 - logger.info('import kernel successful.') + logger.info('Successfully imported Triton kernel.') else: from .quant import weight_cast_to_bf16, weight_cast_to_fp8 - logger.info('import quant successful.') + logger.info( + 'Triton kernel not available: non-Hopper GPU detected.\n' + 'Using LLMC Quantizer implementation instead.' + ) class AutoClipper: diff --git a/llmc/compression/quantization/awq.py b/llmc/compression/quantization/awq.py index 4bf4e8868..6fd9b87f6 100755 --- a/llmc/compression/quantization/awq.py +++ b/llmc/compression/quantization/awq.py @@ -13,10 +13,11 @@ if is_fp8_supported_gpu(): from .kernel import weight_cast_to_bf16, weight_cast_to_fp8 - logger.info('import kernel successful.') + logger.info('Successfully imported Triton kernel.') else: from .quant import weight_cast_to_bf16, weight_cast_to_fp8 - logger.info('import quant successful.') + logger.info('Triton kernel not available (non-Hopper GPU detected). \ + Falling back to LLMC Quantizer implementation.') from .module_utils import (_LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, _TRANSFORMERS_LINEAR_TYPES_, diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 6d3ca4c1f..33927f35d 100755 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -21,10 +21,13 @@ if is_fp8_supported_gpu(): from .kernel import weight_cast_to_bf16, weight_cast_to_fp8 - logger.info('import kernel successful.') + logger.info('Successfully imported Triton kernel.') else: from .quant import weight_cast_to_bf16, weight_cast_to_fp8 - logger.info('import quant successful.') + logger.info( + 'Triton kernel not available: non-Hopper GPU detected.\n' + 'Using LLMC Quantizer implementation instead.' 
+ ) from .hadamard_utils import apply_exact_had_to_linear, get_hadK from .module_utils import (_LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py index 959fd961b..4251819ea 100755 --- a/llmc/compression/quantization/module_utils.py +++ b/llmc/compression/quantization/module_utils.py @@ -14,10 +14,15 @@ if is_fp8_supported_gpu(): from .kernel import act_quant, fp8_gemm, weight_cast_to_bf16 USE_FP8GEMM_TRITON_KERNEL = True - logger.info('import kernel successful.') + logger.info('Successfully imported Triton kernel.') else: USE_FP8GEMM_TRITON_KERNEL = False from .quant import weight_cast_to_bf16 + logger.info( + 'Triton kernel not available: non-Hopper GPU detected.\n' + 'Using LLMC Quantizer implementation instead.' + ) + try: import fast_hadamard_transform @@ -58,13 +63,13 @@ class LlmcWanTransformerBlock(nn.Module): def __init__(self, module): super().__init__() - self.norm1 = FakeAffineLayerNorm(module.norm1, module.scale_shift_table.shape[-1]) + self.affine_norm1 = FakeAffineLayerNorm(module.norm1, module.scale_shift_table.shape[-1]) self.attn1 = module.attn1 self.attn2 = module.attn2 self.norm2 = module.norm2 - self.norm3 = FakeAffineLayerNorm(module.norm1, module.scale_shift_table.shape[-1]) + self.affine_norm3 = FakeAffineLayerNorm(module.norm1, module.scale_shift_table.shape[-1]) self.ffn = module.ffn self.scale_shift_table = module.scale_shift_table @@ -80,11 +85,11 @@ def forward( ).chunk(6, dim=1) # 1. Self-attention - norm1_weight = (1 + scale_msa) * self.norm1.weight - norm1_bias = shift_msa * self.norm1.bias + norm1_weight = (1 + scale_msa) * self.affine_norm1.weight + norm1_bias = shift_msa * self.affine_norm1.bias norm_hidden_states = ( - self.norm1(hidden_states.float()) * norm1_weight + norm1_bias + self.affine_norm1(hidden_states.float()) * norm1_weight + norm1_bias ).type_as(hidden_states) attn_output = self.attn1( hidden_states=norm_hidden_states, rotary_emb=rotary_emb @@ -102,11 +107,11 @@ def forward( hidden_states = hidden_states + attn_output # 3. 
Feed-forward - norm3_weight = (1 + c_scale_msa) * self.norm3.weight - norm3_bias = c_shift_msa * self.norm3.bias + norm3_weight = (1 + c_scale_msa) * self.affine_norm3.weight + norm3_bias = c_shift_msa * self.affine_norm3.bias norm_hidden_states = ( - self.norm3(hidden_states.float()) * norm3_weight + norm3_bias + self.affine_norm3(hidden_states.float()) * norm3_weight + norm3_bias ).type_as(hidden_states) ff_output = self.ffn(norm_hidden_states) hidden_states = ( @@ -872,8 +877,8 @@ def __repr__(self): class LightllmRealQuantLinear(VllmRealQuantLinear): - def __init__(self, weight, bias, scales, input_scale, need_pack): - super().__init__(weight, bias, scales, input_scale, need_pack) + def __init__(self, weight, bias, scales, input_scale, need_pack, scales_name): + super().__init__(weight, bias, scales, input_scale, need_pack, scales_name) def __repr__(self): return ( @@ -890,9 +895,28 @@ def __repr__(self): ) +class Lightx2vRealQuantLinear(VllmRealQuantLinear): + def __init__(self, weight, bias, scales, input_scale, need_pack, scales_name): + super().__init__(weight, bias, scales, input_scale, need_pack, scales_name) + + def __repr__(self): + return ( + 'Lightx2vRealQuantLinear(' + + f'in_features={self.in_features}, ' + + f'out_features={self.out_features}, ' + + f'bias={self.bias is not None}, ' + + f'weight_shape={self.weight_shape}, ' + + f'weight_dtype={self.weight_dtype}, ' + + f'scales_shape={self.scales_shape}, ' + + f'scales_dtype={self.scales_dtype}, ' + + f'zeros_shape={self.zeros_shape}, ' + + f'zeros_dtype={self.zeros_dtype})' + ) + + class SglRealQuantLinear(VllmRealQuantLinear): - def __init__(self, weight, bias, scales, input_scale, need_pack): - super().__init__(weight, bias, scales, input_scale, need_pack) + def __init__(self, weight, bias, scales, input_scale, need_pack, scales_name): + super().__init__(weight, bias, scales, input_scale, need_pack, scales_name) def __repr__(self): return ( @@ -1110,4 +1134,5 @@ def __repr__(self): 'sgl_quant': SglRealQuantLinear, 'autoawq_quant': AutoawqRealQuantLinear, 'mlcllm_quant': MlcllmRealQuantLinear, + 'lightx2v_quant': Lightx2vRealQuantLinear, } diff --git a/llmc/compression/quantization/osplus.py b/llmc/compression/quantization/osplus.py index 283008c82..d5efb1cb5 100755 --- a/llmc/compression/quantization/osplus.py +++ b/llmc/compression/quantization/osplus.py @@ -17,7 +17,13 @@ if is_fp8_supported_gpu(): from .kernel import weight_cast_to_bf16, weight_cast_to_fp8 - logger.info('import kernel successful.') + logger.info('Successfully imported Triton kernel.') +else: + from .quant import weight_cast_to_bf16, weight_cast_to_fp8 + logger.info( + 'Triton kernel not available: non-Hopper GPU detected.\n' + 'Using LLMC Quantizer implementation instead.' 
+ ) @ALGO_REGISTRY diff --git a/llmc/compression/quantization/quarot.py b/llmc/compression/quantization/quarot.py old mode 100644 new mode 100755 diff --git a/llmc/compression/token_reduction/__init__.py b/llmc/compression/token_reduction/__init__.py old mode 100644 new mode 100755 diff --git a/llmc/compression/token_reduction/sparsevlm.py b/llmc/compression/token_reduction/sparsevlm.py old mode 100644 new mode 100755 diff --git a/llmc/compression/token_reduction/utils.py b/llmc/compression/token_reduction/utils.py old mode 100644 new mode 100755 diff --git a/llmc/compression/token_reduction/visionzip.py b/llmc/compression/token_reduction/visionzip.py old mode 100644 new mode 100755 diff --git a/llmc/data/dataset/base_dataset.py b/llmc/data/dataset/base_dataset.py index 01e16ca3f..55c36f84a 100755 --- a/llmc/data/dataset/base_dataset.py +++ b/llmc/data/dataset/base_dataset.py @@ -27,7 +27,7 @@ def __init__(self, tokenizer, calib_cfg, batch_process=None): self.apply_chat_template = calib_cfg.get('apply_chat_template', False) self.n_samples = calib_cfg.get('seq_len', None) self.calib_bs = calib_cfg['bs'] - if self.calib_dataset_name == 'custom_t2v': + if self.calib_dataset_name in ['t2v', 'i2v']: assert self.calib_bs == 1 self.seq_len = calib_cfg.get('seq_len', None) self.preproc = calib_cfg.get('preproc', False) @@ -36,14 +36,14 @@ def __init__(self, tokenizer, calib_cfg, batch_process=None): if self.preproc == 'original_txt': assert self.seq_len is None self.seed = calib_cfg['seed'] - self.dataset_key = { + self.calib_dataset_field_map = { 'pileval': 'text', 'c4': 'text', 'wikitext2': 'text', 'ptb': 'sentence', } - if self.calib_dataset_name in self.dataset_key: - self.key = self.dataset_key[self.calib_dataset_name] + if self.calib_dataset_name in self.calib_dataset_field_map: + self.key = self.calib_dataset_field_map[self.calib_dataset_name] self.build_calib_dataset() def build_calib_dataset(self): @@ -77,15 +77,16 @@ def build_calib_dataset(self): 'custom_txt', 'custom_mm', 'images', - 'custom_t2v', + 't2v', + 'i2v', ]: - self.calib_dataset = self.get_cutomdata(self.calib_dataset_path) + self.calib_dataset = self.get_custom_dataset(self.calib_dataset_path) else: self.calib_dataset = load_from_disk(self.calib_dataset_path) def get_calib_model_inputs(self, samples): if not self.padding: - if self.calib_dataset_name in ['custom_t2v']: + if self.calib_dataset_name in ['t2v', 'i2v']: calib_model_inputs = samples elif self.calib_dataset_name == 'images': calib_model_inputs = self.get_batch_process(samples) @@ -182,8 +183,8 @@ def get_calib_dataset(self): padding_mask = None return calib_model_inputs, padding_mask - def get_cutomdata(self, custom_dataset): - audio_img_qa_json = os.path.join(custom_dataset, 'samples.json') + def get_custom_dataset(self, custom_dataset_path): + audio_img_qa_json = os.path.join(custom_dataset_path, 'samples.json') fp = open(audio_img_qa_json) custom_data_samples = json.load(fp) for idx in range(len(custom_data_samples)): @@ -191,11 +192,11 @@ def get_cutomdata(self, custom_dataset): if isinstance(custom_data_samples[idx]['audio'], list): for audio_idx in range(len(custom_data_samples[idx]['audio'])): custom_data_samples[idx]['audio'][audio_idx] = os.path.join( - custom_dataset, custom_data_samples[idx]['audio'][audio_idx] + custom_dataset_path, custom_data_samples[idx]['audio'][audio_idx] ) else: custom_data_samples[idx]['audio'] = os.path.join( - custom_dataset, custom_data_samples[idx]['audio'] + custom_dataset_path, custom_data_samples[idx]['audio'] ) else: 
custom_data_samples[idx]['audio'] = None @@ -203,11 +204,11 @@ def get_cutomdata(self, custom_dataset): if isinstance(custom_data_samples[idx]['image'], list): for img_idx in range(len(custom_data_samples[idx]['image'])): custom_data_samples[idx]['image'][img_idx] = os.path.join( - custom_dataset, custom_data_samples[idx]['image'][img_idx] + custom_dataset_path, custom_data_samples[idx]['image'][img_idx] ) else: custom_data_samples[idx]['image'] = os.path.join( - custom_dataset, custom_data_samples[idx]['image'] + custom_dataset_path, custom_data_samples[idx]['image'] ) else: custom_data_samples[idx]['image'] = None diff --git a/llmc/eval/eval_base.py b/llmc/eval/eval_base.py index 10f77070a..7916b2059 100755 --- a/llmc/eval/eval_base.py +++ b/llmc/eval/eval_base.py @@ -19,9 +19,9 @@ def __init__(self, model, config): self.eval_cfg = config.eval self.model_type = config.model.type logger.info(f'eval_cfg : {self.eval_cfg}') - self.dataset = self.eval_cfg['name'] + self.eval_dataset_name = self.eval_cfg['name'] self.dataset_type = self.eval_cfg.get('type', 'ppl') - assert self.dataset in [ + assert self.eval_dataset_name in [ 'wikitext2', 'c4', 'ptb', @@ -30,7 +30,8 @@ def __init__(self, model, config): 'mme', 'custom_ppl', 'custom_gen', - 'custom_t2v' + 't2v', + 'i2v', ], f'Not support {self.dataset} dataset now.' self.seq_len = self.eval_cfg.get('seq_len', None) self.num_samples = self.eval_cfg.get('num_samples', None) @@ -46,15 +47,15 @@ def __init__(self, model, config): @torch.no_grad() def build_data(self): # load data - if self.dataset == 'human_eval': + if self.eval_dataset_name == 'human_eval': testenc = read_problems() else: if self.download: - if self.dataset == 'wikitext2': + if self.eval_dataset_name == 'wikitext2': testdata = load_dataset( 'wikitext', 'wikitext-2-raw-v1', split='test' ) - elif self.dataset == 'c4': + elif self.eval_dataset_name == 'c4': testdata = load_dataset( 'allenai/c4', data_files={ @@ -62,12 +63,12 @@ def build_data(self): }, split='validation', ) - elif self.dataset == 'ptb': + elif self.eval_dataset_name == 'ptb': testdata = load_dataset( 'ptb_text_only', 'penn_treebank', split='test' ) else: - if self.dataset in ['custom_gen', 'custom_ppl', 'custom_t2v']: + if self.eval_dataset_name in ['custom_gen', 'custom_ppl', 't2v', 'i2v']: testdata = self.get_cutomdata(self.eval_dataset_path) else: assert self.eval_dataset_path, 'Please set path in eval_cfg.' 
@@ -75,27 +76,27 @@ def build_data(self): self.testdata = testdata # encode data if self.dataset_type == 'decode_ppl': - assert self.dataset == 'wikitext2' + assert self.eval_dataset_name == 'wikitext2' testenc = testdata['text'] - elif self.dataset == 'wikitext2': + elif self.eval_dataset_name == 'wikitext2': testenc = self.tokenizer( '\n\n'.join(testdata['text']), return_tensors='pt' ) - elif self.dataset == 'c4': + elif self.eval_dataset_name == 'c4': testenc = self.tokenizer( ' '.join(testdata[:1100]['text']), return_tensors='pt' ) testenc.input_ids = testenc.input_ids[:, : (256 * self.seq_len)] - elif self.dataset == 'ptb': + elif self.eval_dataset_name == 'ptb': testenc = self.tokenizer( ' '.join(testdata['sentence']), return_tensors='pt' ) - elif self.dataset == 'custom_ppl': + elif self.eval_dataset_name == 'custom_ppl': testenc = self.tokenizer( '\n'.join([data['question'] + data['answer'] if 'answer' in data else data['question'] for data in testdata]), # noqa return_tensors='pt', ) - elif self.dataset == 'custom_gen': + elif self.eval_dataset_name == 'custom_gen': testenc = [] if self.eval_dataset_bs < 0: testenc.append( @@ -126,7 +127,7 @@ def build_data(self): apply_chat_template=self.apply_chat_template ) ) - elif self.dataset == 'custom_t2v': + elif self.eval_dataset_name in ['t2v', 'i2v']: testenc = self.testdata return testenc diff --git a/llmc/eval/eval_video_generate.py b/llmc/eval/eval_video_generate.py old mode 100644 new mode 100755 index 8bbac352d..0f99ff6c9 --- a/llmc/eval/eval_video_generate.py +++ b/llmc/eval/eval_video_generate.py @@ -1,8 +1,9 @@ import gc import os +import numpy as np import torch -from diffusers.utils import export_to_video +from diffusers.utils import export_to_video, load_image from loguru import logger from llmc.utils import seed_all @@ -40,8 +41,18 @@ def eval(self, model_llmc, eval_pos): torch.cuda.empty_cache() return eval_res + def pre_process(self, model, image_path): + image = load_image(image_path) + max_area = self.target_height * self.target_width + aspect_ratio = image.height / image.width + mod_value = model.Pipeline.vae_scale_factor_spatial * model.model.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, width, height + @torch.no_grad() - def eval_func(self, model, testenc, bs, eval_pos): + def t2v_eval(self, model, testenc, bs, eval_pos): assert bs == 1, 'Only support eval bs=1' for i, data in enumerate(testenc): @@ -60,3 +71,40 @@ def eval_func(self, model, testenc, bs, eval_pos): ) return None + + @torch.no_grad() + def i2v_eval(self, model, testenc, bs, eval_pos): + for i, data in enumerate(testenc): + image, width, height = self.pre_process(model, data['image']) + + output = model.Pipeline( + image=image, + prompt=data['prompt'], + negative_prompt=data['negative_prompt'], + height=height, + width=width, + num_frames=self.num_frames, + guidance_scale=self.guidance_scale, + ).frames[0] + + export_to_video( + output, + os.path.join(self.output_video_path, f'{eval_pos}_output_{i}.mp4'), + fps=self.fps, + ) + + return None + + @torch.no_grad() + def eval_func(self, model, testenc, bs, eval_pos): + assert bs == 1, 'Evaluation only supports batch size = 1.' + assert self.model_type in ['WanT2V', 'WanI2V'], ( + f"Unsupported model type '{self.model_type}'.\n" + 'Only Wan2.1 video generation models (WanT2V, WanI2V) are supported.' 
+ ) + if self.eval_dataset_name == 't2v': + return self.t2v_eval(model, testenc, bs, eval_pos) + elif self.eval_dataset_name == 'i2v': + return self.i2v_eval(model, testenc, bs, eval_pos) + else: + raise Exception(f'Unsupported eval dataset: {self.eval_dataset_name}') diff --git a/llmc/eval/eval_vqa.py b/llmc/eval/eval_vqa.py index 350ac61c0..f4cb9e0f5 100755 --- a/llmc/eval/eval_vqa.py +++ b/llmc/eval/eval_vqa.py @@ -18,9 +18,9 @@ class VQAEval: def __init__(self, config): self.eval_config = config.eval self.model_path = config.model.path - self.dataset = self.eval_config['name'] + self.eval_dataset_name = self.eval_config['name'] if not isinstance(self.dataset, list): - self.dataset = [self.dataset, ] + self.eval_dataset_name = [self.dataset, ] self.eval_dataset_path = self.eval_config['path'] self.eval_bs = self.eval_config['bs'] diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py index 841ae81a2..d57751a5e 100755 --- a/llmc/models/__init__.py +++ b/llmc/models/__init__.py @@ -28,4 +28,5 @@ from .starcoder import Starcoder from .vila import Vila from .vit import Vit +from .wan_i2v import WanI2V from .wan_t2v import WanT2V diff --git a/llmc/models/qwen2vl.py b/llmc/models/qwen2vl.py old mode 100644 new mode 100755 diff --git a/llmc/models/wan_i2v.py b/llmc/models/wan_i2v.py new file mode 100755 index 000000000..09dd02e31 --- /dev/null +++ b/llmc/models/wan_i2v.py @@ -0,0 +1,131 @@ +import inspect +import json +import os +from collections import defaultdict + +import numpy as np +import torch +import torch.nn as nn +from diffusers import AutoencoderKLWan, WanImageToVideoPipeline +from diffusers.utils import load_image +from loguru import logger +from PIL import Image +from safetensors import safe_open +from transformers import CLIPVisionModel + +from llmc.compression.quantization.module_utils import LlmcWanTransformerBlock +from llmc.utils import seed_all +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .wan_t2v import WanT2V + + +@MODEL_REGISTRY +class WanI2V(WanT2V): + def __init__(self, config, device_map=None, use_cache=False): + super().__init__(config, device_map, use_cache) + + def build_model(self): + image_encoder = CLIPVisionModel.from_pretrained( + self.model_path, subfolder='image_encoder', torch_dtype=torch.float32 + ) + vae = AutoencoderKLWan.from_pretrained( + self.model_path, subfolder='vae', torch_dtype=torch.float32 + ) + self.Pipeline = WanImageToVideoPipeline.from_pretrained( + self.model_path, + vae=vae, + image_encoder=image_encoder, + torch_dtype=torch.bfloat16, + ) + self.find_llmc_model() + self.find_blocks() + for block_idx, block in enumerate(self.blocks): + new_block = LlmcWanTransformerBlock.new(block) + self.Pipeline.transformer.blocks[block_idx] = new_block + self.lora_path = self.config.model.get('lora_path', None) + if self.lora_path is not None: + logger.info('Loading lora weights...') + self.load_lora_weights() + + logger.info(f'self.model : {self.model}') + + def pre_process(self, image_path): + image = load_image(image_path) + max_area = self.target_height * self.target_width + aspect_ratio = image.height / image.width + mod_value = ( + self.Pipeline.vae_scale_factor_spatial * self.model.config.patch_size[1] + ) + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, width, height + + @torch.no_grad() + def collect_first_block_input(self, calib_data, padding_mask=None): + 
first_block_input = defaultdict(list) + Catcher = self.get_catcher(first_block_input) + self.blocks[0] = Catcher(self.blocks[0]) + self.Pipeline.to('cuda') + for data in calib_data: + self.blocks[0].step = 0 + try: + image, width, height = self.pre_process(data['image']) + self.Pipeline( + image=image, + prompt=data['prompt'], + negative_prompt=data['negative_prompt'], + height=height, + width=width, + num_frames=self.num_frames, + guidance_scale=self.guidance_scale, + ) + except ValueError: + pass + + self.first_block_input = first_block_input + assert len(self.first_block_input['data']) > 0, 'Catch input data failed.' + self.n_samples = len(self.first_block_input['data']) + logger.info(f'Retrieved {self.n_samples} calibration samples for T2V.') + self.blocks[0] = self.blocks[0].module + self.Pipeline.to('cpu') + + def load_lora_weights(self, alpha=1.0): + state_dict = self.model.state_dict() + model_index_file = os.path.join( + self.lora_path, 'diffusion_pytorch_model.safetensors.index.json' + ) + + with open(model_index_file, 'r') as f: + model_index = json.load(f) + + weight_map = model_index['weight_map'] + model_keys = list(state_dict.keys()) + + matched_keys = {} + for model_key in model_keys: + if not model_key.endswith('.weight'): + continue + base_name = model_key.replace('.weight', '') + for lora_key in weight_map: + if base_name in lora_key: + if model_key not in matched_keys: + matched_keys[model_key] = [] + matched_keys[model_key].append(lora_key) + + for weight_name in matched_keys: + lora_A_name, lora_B_name = matched_keys[weight_name] + weight = state_dict[weight_name].cuda() + lora_A_path = os.path.join(self.lora_path, weight_map[lora_A_name]) + with safe_open(lora_A_path, framework='pt', device='cuda') as f: + lora_A_weight = f.get_tensor(lora_A_name) + + lora_B_path = os.path.join(self.lora_path, weight_map[lora_B_name]) + with safe_open(lora_B_path, framework='pt', device='cuda') as f: + lora_B_weight = f.get_tensor(lora_B_name) + + merge_weight = weight + (lora_B_weight @ lora_A_weight) * alpha + state_dict[weight_name] = merge_weight.cpu() + + self.model.load_state_dict(state_dict) diff --git a/llmc/models/wan_t2v.py b/llmc/models/wan_t2v.py old mode 100644 new mode 100755 index e7b89286c..885bccda3 --- a/llmc/models/wan_t2v.py +++ b/llmc/models/wan_t2v.py @@ -112,9 +112,9 @@ def __str__(self): def get_layernorms_in_block(self, block): return { - 'norm1': block.norm1, + 'affine_norm1': block.affine_norm1, 'norm2': block.norm2, - 'norm3': block.norm3, + 'affine_norm3': block.affine_norm3, } def get_subsets_in_block(self, block): @@ -125,7 +125,7 @@ def get_subsets_in_block(self, block): 'attn1.to_k': block.attn1.to_k, 'attn1.to_v': block.attn1.to_v, }, - 'prev_op': [block.norm1], + 'prev_op': [block.affine_norm1], 'input': ['attn1.to_q'], 'inspect': block.attn1, 'has_kwargs': True, @@ -145,7 +145,7 @@ def get_subsets_in_block(self, block): 'layers': { 'ffn.net.0.proj': block.ffn.net[0].proj, }, - 'prev_op': [block.norm3], + 'prev_op': [block.affine_norm3], 'input': ['ffn.net.0.proj'], 'inspect': block.ffn, 'has_kwargs': True, diff --git a/requirements/runtime.txt b/requirements/runtime.txt old mode 100644 new mode 100755 diff --git a/scripts/run_llmc.sh b/scripts/run_llmc.sh old mode 100644 new mode 100755 index 180dcadcc..d90877f69 --- a/scripts/run_llmc.sh +++ b/scripts/run_llmc.sh @@ -43,4 +43,4 @@ ps aux | grep '__main__.py' | grep $task_id | awk '{print $2}' > ${task_name}.pi # You can kill this program by # xargs kill -9 < xxx.pid -# xxx.pid is ${task_name}.pid 
file +# xxx.pid is ${task_name}.pid file \ No newline at end of file diff --git a/scripts/run_lm_eval.sh b/scripts/run_lm_eval.sh old mode 100644 new mode 100755
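
A note on the quantization format referenced throughout the configs and docs above: "W8A8, symmetric, weight per-channel, activation per-token dynamic" means one scale per weight output channel (computed offline) and one scale per activation token (computed at runtime), with no zero-points. The sketch below illustrates only the numerics of that scheme; it is not the LLMC or lightx2v implementation, and the function names are invented for the example.

```python
# Illustrative sketch of symmetric W8A8: per-channel weight scales,
# per-token dynamic activation scales, scale-only (no zero-point).
import torch


def quant_weight_per_channel_int8(w: torch.Tensor):
    # w: [out_features, in_features]; one scale per output channel
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale


def quant_act_per_token_int8(x: torch.Tensor):
    # x: [tokens, features]; one scale per token, computed at runtime (dynamic)
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale


w, x = torch.randn(4, 8), torch.randn(3, 8)
qw, sw = quant_weight_per_channel_int8(w)
qx, sx = quant_act_per_token_int8(x)
# Dequantize and multiply; the result approximates the float x @ w.T
print((qx.float() * sx) @ (qw.float() * sw).t())
print(x @ w.t())
```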