Commit b8dbfc0

[OMNIML-2182]: Add example for multinode calibration using FSDP2 (#432)
Signed-off-by: Suguna Velury <[email protected]>
1 parent eb9e31e commit b8dbfc0

File tree: 13 files changed, +1102 −214 lines

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ Model Optimizer Changelog (Linux)
 - Add support for ``nemotron-post-training-dataset-v2`` and ``nemotron-post-training-dataset-v1`` in ``examples/llm_ptq``. Default to a mix of ``cnn_dailymail`` and ``nemotron-post-training-dataset-v2`` (gated dataset accessed using ``HF_TOKEN`` environment variable) if no dataset is specified.
 - Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration.
 - Add support for MCore MoE PTQ/QAT/QAD.
+- Add support for multi-node PTQ and export with FSDP2 in ``examples/llm_ptq/multinode_ptq.py``. See `examples/llm_ptq/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/llm_ptq#multi-node-post-training-quantization-with-fsdp2>`_ for more details.

 **Documentation**

examples/llm_ptq/README.md

Lines changed: 32 additions & 0 deletions

@@ -235,6 +235,38 @@ with init_quantized_weights(mtq.NVFP4_DEFAULT_CFG):
     mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop)
 ```

+## Multi-Node Post-Training Quantization with FSDP2
+
+ModelOpt enables quantization of LLMs across multiple GPU nodes using various quantization formats. It leverages Hugging Face's Accelerate library and FSDP2 for distributed model sharding and calibration.
+
+### Usage
+
+For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user-specific requirements.
+
+On each node, run the following command:
+
+```bash
+accelerate launch --config_file fsdp2.yaml \
+    --num_machines=<num_nodes> \
+    --machine_rank=<current_node_rank> \
+    --main_process_ip=<node0_ip_addr> \
+    --main_process_port=<port> \
+    --fsdp_transformer_layer_cls_to_wrap=<decoder_layer_name> \
+    multinode_ptq.py \
+    --pyt_ckpt_path <path_to_model> \
+    --qformat <fp8/nvfp4/nvfp4_awq/int8> \
+    --kv_cache_qformat <fp8/nvfp4/nvfp4_affine/none> \
+    --batch_size <calib_batch_size> \
+    --calib_size <num_calib_samples> \
+    --dataset <dataset> \
+    --export_path <export_path> \
+    --trust_remote_code
+```
+
+The exported checkpoint can be deployed with TensorRT-LLM, vLLM, or SGLang. For more details, refer to the [deployment section](#deployment) of this document.
+
+> *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size based on available GPU memory and choose the right number of GPUs to avoid unnecessary communication.*

 ## Framework Scripts

 ### Hugging Face Example [Script](./scripts/huggingface_example.sh)
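For orientation, here is a minimal, hypothetical Python sketch of the calibrate-then-export flow that a launch command like the one in the README section above drives. It is not the `multinode_ptq.py` added in this commit: the toy calibration texts, the placeholder paths, and the single-call export are assumptions, and the real script handles dataset streaming, batching, and the distributed bookkeeping that FSDP2 export requires. It assumes the process was started via `accelerate launch --config_file fsdp2.yaml` so that `Accelerator()` picks up the FSDP2 settings.

```python
# Hypothetical sketch of an FSDP2-sharded PTQ flow (not the actual multinode_ptq.py).
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

MODEL_PATH = "<path_to_model>"   # placeholder
EXPORT_PATH = "<export_path>"    # placeholder

accelerator = Accelerator()  # reads the FSDP2 config when started via `accelerate launch`
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16)
model = accelerator.prepare(model)  # shards the model across all ranks with FSDP2

# A toy calibration set; the real script streams batches from --dataset.
calib_texts = ["The quick brown fox jumps over the lazy dog."] * 8

def forward_loop(mod):
    # Run calibration data through the sharded model so quantizers collect statistics.
    for text in calib_texts:
        batch = tokenizer(text, return_tensors="pt").to(accelerator.device)
        mod(**batch)

# Insert quantizers and calibrate (FP8 shown; other formats follow the same pattern).
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# Export a Hugging Face-style quantized checkpoint for TensorRT-LLM / vLLM / SGLang.
export_hf_checkpoint(model, export_dir=EXPORT_PATH)
accelerator.wait_for_everyone()
```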

examples/llm_ptq/example_utils.py

Lines changed: 56 additions & 0 deletions

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import glob
 import os
 import shutil
@@ -32,11 +33,66 @@
 except ImportError:
     snapshot_download = None

+import modelopt.torch.quantization as mtq
 from modelopt.torch.utils.image_processor import MllamaImageProcessor

 SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]


+def build_quant_cfg(
+    qformat,
+    kv_cache_qformat,
+    awq_block_size,
+    auto_quantize,
+    model_type,
+    quant_cfg_choices,
+    kv_quant_cfg_choices,
+):
+    quant_cfg = {}
+    if not auto_quantize:
+        assert qformat in quant_cfg_choices, (
+            f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
+        )
+
+        quant_cfg = quant_cfg_choices[qformat]
+
+        if "awq" in qformat:
+            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
+            weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
+            if isinstance(weight_quantizer, list):
+                weight_quantizer = weight_quantizer[0]
+            # If awq_block_size argument is provided, update weight_quantizer
+            if awq_block_size:
+                weight_quantizer["block_sizes"][-1] = awq_block_size
+
+            # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
+            if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
+                quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
+
+        enable_quant_kv_cache = kv_cache_qformat != "none"
+        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
+
+        # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
+        if enable_quant_kv_cache:
+            quant_cfg = apply_kv_cache_quant(
+                quant_cfg,
+                getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
+            )
+
+        # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
+        if model_type == "gemma" and "int8_sq" in qformat:
+            quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
+
+        if model_type == "phi4mm":
+            # Only quantize the language model
+            quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+
+    return quant_cfg
+
+
 def is_speculative(hf_config):
     """Check if the model architecture is a speculative model."""
     return hf_config.architectures and any(
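A quick illustration of how the new `build_quant_cfg` helper is meant to be called. The choice dictionaries below are hypothetical stand-ins for the `QUANT_CFG_CHOICES` / `KV_QUANT_CFG_CHOICES` tables defined in `hf_ptq.py`; only their shape (format name mapped to a ModelOpt config, or to a `mtq` config attribute name for the KV cache) matters here.

```python
# Hypothetical usage of build_quant_cfg; the real choice tables live in hf_ptq.py.
import modelopt.torch.quantization as mtq
from example_utils import build_quant_cfg

# Stand-in choice tables (assumed names): qformat -> config, kv format -> mtq attribute name.
QUANT_CFG_CHOICES = {"fp8": mtq.FP8_DEFAULT_CFG, "nvfp4": mtq.NVFP4_DEFAULT_CFG}
KV_QUANT_CFG_CHOICES = {"fp8": "FP8_KV_CFG"}

quant_cfg = build_quant_cfg(
    qformat="fp8",
    kv_cache_qformat="fp8",   # "none" disables KV-cache quantization
    awq_block_size=0,         # only consulted for AWQ formats
    auto_quantize=False,      # with auto-quantize the helper skips config selection
    model_type="llama",
    quant_cfg_choices=QUANT_CFG_CHOICES,
    kv_quant_cfg_choices=KV_QUANT_CFG_CHOICES,
)

# quant_cfg can then be passed to mtq.quantize(model, quant_cfg, forward_loop).
```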

examples/llm_ptq/fsdp2.yaml

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+# =============================================================================
+# FSDP Configuration for running LLM PTQ on multinode setup. This file is consumed by examples/llm_ptq/multinode_ptq.py
+# =============================================================================
+
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+fsdp_config:
+  fsdp_activation_checkpointing: false
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_offload_params: false
+  fsdp_reshard_after_forward: true
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_use_orig_params: true
+  fsdp_version: 2
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 2
+num_processes: 16
+rdzv_backend: c10d
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
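The template hard-codes a 2-node by 8-GPU layout (`num_machines: 2`, `num_processes: 16`) and Llama-style layer wrapping. The launch flags shown in the README override these per run, but the file can also be patched directly. A small optional sketch below, assuming PyYAML is available (the example itself does not require this step) and using a made-up decoder-layer name:

```python
# Optional convenience: rewrite the fsdp2.yaml template for a different cluster/model.
import yaml  # PyYAML is an assumption, not a dependency of the example

NUM_NODES = 4                        # machines in the job
GPUS_PER_NODE = 8                    # GPUs on each machine
DECODER_LAYER = "Qwen2DecoderLayer"  # hypothetical decoder-layer class for your model

with open("fsdp2.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["num_machines"] = NUM_NODES
cfg["num_processes"] = NUM_NODES * GPUS_PER_NODE  # accelerate expects the global process count
cfg["fsdp_config"]["fsdp_transformer_layer_cls_to_wrap"] = DECODER_LAYER

with open("fsdp2_custom.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```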

examples/llm_ptq/hf_ptq.py

Lines changed: 10 additions & 42 deletions

@@ -14,7 +14,6 @@
 # limitations under the License.

 import argparse
-import copy
 import random
 import time
 import warnings
@@ -25,6 +24,7 @@
 from accelerate.hooks import remove_hook_from_module
 from example_utils import (
     apply_kv_cache_quant,
+    build_quant_cfg,
     copy_custom_model_files,
     get_model,
     get_processor,
@@ -448,47 +448,15 @@ def main(args):
         include_labels=args.auto_quantize_bits is not None,
     )

-    quant_cfg = {}
-    if not args.auto_quantize_bits:
-        assert args.qformat in QUANT_CFG_CHOICES, (
-            f"Unsupported quantization format: {args.qformat} with {args.kv_cache_qformat} KV cache"
-        )
-
-        quant_cfg = QUANT_CFG_CHOICES[args.qformat]
-
-        if "awq" in args.qformat:
-            quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
-            weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
-            if isinstance(weight_quantizer, list):
-                weight_quantizer = weight_quantizer[0]
-            # If awq_block_size argument is provided, update weight_quantizer
-            if args.awq_block_size:
-                weight_quantizer["block_sizes"][-1] = args.awq_block_size
-
-            # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
-            if args.qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
-                quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
-
-        enable_quant_kv_cache = args.kv_cache_qformat != "none"
-        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
-
-        # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
-        if enable_quant_kv_cache:
-            quant_cfg = apply_kv_cache_quant(
-                quant_cfg,
-                getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
-            )
-
-        # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
-        if model_type == "gemma" and "int8_sq" in args.qformat:
-            quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
-
-        if model_type == "phi4mm":
-            # Only quantize the language model
-            quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
-            quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
-            quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
-            quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+    quant_cfg = build_quant_cfg(
+        args.qformat,
+        args.kv_cache_qformat,
+        args.awq_block_size,
+        args.auto_quantize_bits,
+        model_type,
+        QUANT_CFG_CHOICES,
+        KV_QUANT_CFG_CHOICES,
+    )

     if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview