Description
Is your feature request related to a problem? Please describe.
I am trying to apply Post-Training Quantization (specifically NVFP4) to the GLM-4.5 model using Tensor Parallelism across multiple nodes. My goal is to create a script similar to the ptq.py example, but for GLM-4.5.
The main challenge is that I'm loading the model using Hugging Face transformers with tp_plan="auto", which seems to handle Tensor Parallelism differently than the DeepSeek example. Instead of using distinct ColumnParallelLinear or RowParallelLinear classes, it uses standard torch.nn.Linear layers and manages the TP logic internally.
This leads to a persistent warning during the modelopt.torch.quantization.quantize call, and I'm unsure if the quantization is being applied correctly for the TP layers.
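For reference, this is roughly how I confirmed what tp_plan="auto" produces: the sharded layers stay plain torch.nn.Linear modules, but their parameters are DTensors. The helper below is just my own quick check, not part of any example:

import torch
from torch.distributed.tensor import DTensor  # recent PyTorch; older versions expose this under torch.distributed._tensor


def inspect_tp_layers(model, limit=5):
    """Print the class and weight type of the first few Linear layers."""
    shown = 0
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            print(f"{name}: {type(module).__name__}, weight is DTensor: {isinstance(module.weight, DTensor)}")
            shown += 1
            if shown >= limit:
                return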
The Warning:
/opt/conda/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py:895: UserWarning: Distributed training is initialized but no parallel_state is set for <class 'modelopt.torch.opt.dynamic.QuantLinear'>. Using default parallel_state which has data_parallel_group set to the default process group and tensor_parallel_group is unspecified. If you are using tensor parallelism for this module, you should set the parallel_state in its `_setup` method.
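For context, this is my understanding of what the warning is asking for, based on the custom-module pattern used for the parallel linear classes in the DeepSeek ptq.py example (I come back to this under alternatives below). The parallel-linear class, the quantized base class, and the choice of process group are placeholders on my side, which is exactly why I cannot apply this directly to GLM-4.5:

# Hypothetical illustration only: the ptq.py-style pattern where a registered quantized
# module sets its own parallel_state inside _setup, as the warning suggests.
import torch.distributed as dist
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import QuantModule  # assumption on my side about the base class
from modelopt.torch.utils.distributed import ParallelState

from some_model_lib import ColumnParallelLinear  # placeholder: GLM-4.5 with tp_plan="auto" has no such class


class QuantColumnParallelLinear(QuantModule):
    def _setup(self):
        # Placeholder groups: -1 = no data parallelism during calibration;
        # dist.group.WORLD stands in for the real tensor-parallel group.
        self.parallel_state = ParallelState(
            data_parallel_group=-1,
            tensor_parallel_group=dist.group.WORLD,
        )
        super()._setup()


mtq.register(original_cls=ColumnParallelLinear, quantized_cls=QuantColumnParallelLinear)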
The script then raises an error in export_hf_checkpoint:
[rank0]: Traceback (most recent call last):
[rank0]: File "/data/numa0/TensorRT-Model-Optimizer/examples/llm_ptq/test.py", line 72, in <module>
[rank0]: main()
[rank0]: File "/data/numa0/TensorRT-Model-Optimizer/examples/llm_ptq/test.py", line 43, in main
[rank0]: export_hf_checkpoint(
[rank0]: File "/opt/conda/lib/python3.11/site-packages/modelopt/torch/export/unified_export_hf.py", line 456, in export_hf_checkpoint
[rank0]: raise e
[rank0]: File "/opt/conda/lib/python3.11/site-packages/modelopt/torch/export/unified_export_hf.py", line 424, in export_hf_checkpoint
[rank0]: post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/modelopt/torch/export/unified_export_hf.py", line 367, in _export_hf_checkpoint
[rank0]: "weight_scale", get_weight_scaling_factor(sub_module)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/modelopt/torch/export/quant_utils.py", line 221, in get_weight_scaling_factor
[rank0]: return DTensor._op_dispatcher.dispatch(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 154, in dispatch
[rank0]: self.sharding_propagator.propagate(op_info)
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 266, in propagate
[rank0]: OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 45, in __call__
[rank0]: return self.cache(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 279, in propagate_op_sharding_non_cached
[rank0]: out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 126, in _propagate_tensor_meta_non_cached
[rank0]: fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 829, in __call__
[rank0]: return self._op(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 53, in inner
[rank0]: return disable_fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
[rank0]: return DTensor._op_dispatcher.dispatch(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 151, in dispatch
[rank0]: op_info = self.unwrap_to_op_info(op_call, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 350, in unwrap_to_op_info
[rank0]: self._try_replicate_spec_for_scalar_tensor(
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 452, in _try_replicate_spec_for_scalar_tensor
[rank0]: raise RuntimeError(
[rank0]: RuntimeError: aten.index_put_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
/opt/conda/lib/python3.11/site-packages/modelopt/torch/export/unified_export_hf.py:451: UserWarning: Cannot export model to the model_config. The modelopt-optimized model state_dict (including the quantization factors) is saved to /data/numa0/downloaded_models/GLM-4.5-Air-nvfp4-tp4-test/modelopt_model.pth using torch.save for further inspection.
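Regarding the export error specifically, one workaround I considered but have not validated is gathering the sharded DTensor parameters back into regular tensors before calling export_hf_checkpoint, roughly as sketched below. Whether this is the intended export path for a TP-sharded model is part of my question:

# Unverified workaround idea: materialize every DTensor parameter/buffer as a full,
# replicated local tensor so the export code only ever sees plain torch.Tensors.
# Must run on every rank, since full_tensor() is a collective.
import torch
from torch.distributed.tensor import DTensor


def gather_dtensors_for_export(model: torch.nn.Module) -> None:
    for module in model.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            if isinstance(param.data, DTensor):
                full = param.data.full_tensor()  # all-gather into a replicated local tensor
                module._parameters[name] = torch.nn.Parameter(full, requires_grad=False)
        for name, buf in list(module.named_buffers(recurse=False)):
            if isinstance(buf, DTensor):
                module._buffers[name] = buf.full_tensor()

I have not actually tried this, because it would replicate the full model on every rank, which defeats the point of using TP for a model of this size.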
Describe the solution you'd like
I would like to understand the correct procedure to prepare a model for TP-aware quantization when it's loaded via tp_plan="auto" and doesn't have explicit parallel layer subclasses. My end goal is to successfully run PTQ on GLM-4.5 in a multi-node environment and export a quantized checkpoint, using tensor parallelism and/or pipeline parallelism.
Describe alternatives you've considered
I have gone through an iterative process to solve this, but the warning persists.
- Initial Investigation: I first tried to identify custom parallel layer classes (like ColumnParallelLinear) by inspecting the model structure. However, the inspection revealed that all relevant layers are standard torch.nn.Linear instances. This means the monkey-patching approach from ptq.py (using mtq.register) is not directly applicable.
- Attempt: Dynamic Attribute Injection. Based on the discovery above, I wrote a function to iterate through the model's modules after it was loaded. This function identifies potential parallel layers by their names (e.g., query_key_value, gate_proj, dense, down_proj) and dynamically injects the required attributes (_is_column_parallel, _is_row_parallel, and _parallel_state) onto the module instances:
import torch
from modelopt.torch.utils.distributed import ParallelState


def prepare_model_for_tp_quantization(model):
    print("\n--- Preparing model for Tensor Parallel Quantization ---")
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if any(p in name for p in ["query_key_value", "gate_proj", "up_proj"]):
                print(f" - Found Column Parallel Layer: {name}")
                module._is_column_parallel = True
                module._parallel_state = ParallelState(data_parallel_group=-1, tensor_parallel_group=None)
            elif any(p in name for p in ["dense", "down_proj"]):
                print(f" - Found Row Parallel Layer: {name}")
                module._is_row_parallel = True
                module._parallel_state = ParallelState(data_parallel_group=-1, tensor_parallel_group=None)
    print("--- Preparation complete ---\n")
Result: The warning still appeared, which suggests that other torch.nn.Linear layers (beyond the ones my function patched) were still being processed by mtq.quantize without a parallel_state.
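A related variation I considered, based on my reading of the quantization config format (not verified to change the warning), was to disable specific layers through wildcard patterns in the config instead of tagging modules with a _do_not_quantize attribute:

# Disabling layers via config wildcards rather than module attributes.
# This reflects my understanding of the format; the default NVFP4 config may already skip lm_head.
import copy
import modelopt.torch.quantization as mtq

config = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
config["quant_cfg"]["*lm_head*"] = {"enable": False}
# model = mtq.quantize(model, config, forward_loop)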
Minimal Reproducible Code
Here is the script that reproduces the issue. It can be run with torchrun --nproc_per_node=<num_gpus> your_script.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
from tqdm import tqdm
from modelopt.torch.utils.distributed import ParallelState


def prepare_model_for_tp_quantization(model):
    """
    Traverses the model instance to dynamically add parallel state to torch.nn.Linear
    layers involved in tensor parallelism. Also excludes layers that should not be quantized.
    """
    print("\n--- Preparing model for Tensor Parallel Quantization ---")
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Column Parallel Layers
            if any(p in name for p in ["query_key_value", "gate_proj", "up_proj"]):
                print(f" - Found Column Parallel Layer: {name}")
                module._is_column_parallel = True
                module._parallel_state = ParallelState(data_parallel_group=-1, tensor_parallel_group=None)
            # Row Parallel Layers
            elif any(p in name for p in ["dense", "down_proj"]):
                print(f" - Found Row Parallel Layer: {name}")
                module._is_row_parallel = True
                module._parallel_state = ParallelState(data_parallel_group=-1, tensor_parallel_group=None)
            # Explicitly exclude layers we don't want to quantize
            elif "lm_head" in name:
                print(f" - Excluding non-parallel layer from quantization: {name}")
                module._do_not_quantize = True
    print("--- Preparation complete ---\n")


def main():
    # Using a public model for reproducibility
    model_name = "zai-org/GLM-4.5-Air"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # This requires a multi-GPU setup and torchrun
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        tp_plan="auto",
    ).eval()

    # Prepare the model for TP quantization
    prepare_model_for_tp_quantization(model)

    config = mtq.NVFP4_DEFAULT_CFG
    batch_size = 1
    num_samples = 4  # Small number for a quick test

    calib_dataset = get_dataset_dataloader(
        dataset_name="cnn_dailymail",
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_samples=num_samples,
    )

    def forward_loop(model):
        for data in tqdm(calib_dataset):
            # model.device is not reliable with tp_plan="auto"; inputs must be on the current rank's device
            model(data["input_ids"].to(torch.cuda.current_device()))

    # PTQ with in-place replacement of modules by their quantized counterparts.
    # This line triggers the warning.
    model = mtq.quantize(model, config, forward_loop)

    print("\nQuantization finished.")
    mtq.print_quant_summary(model)


if __name__ == "__main__":
    main()
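The snippet above stops at the quantization summary; in my actual test.py, main() additionally calls the export step that produces the traceback, roughly like this (the export directory is just my local path, and the exact keyword arguments are from memory):

# The export step from my actual test.py; this is where the DTensor error is raised.
from modelopt.torch.export import export_hf_checkpoint

export_hf_checkpoint(
    model,
    export_dir="/data/numa0/downloaded_models/GLM-4.5-Air-nvfp4-tp4-test",
)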
My Question:
Could you please provide guidance on the correct way to handle quantization and export for this tp_plan="auto" multi-node scenario?
Thank you for your help!