@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+from unittest.mock import patch
 
 import pytest
 import torch
@@ -22,7 +23,9 @@
 
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
+import modelopt.torch.quantization.model_calib as model_calib_module  # needed for patching awq_lite
 from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
+from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
 from modelopt.torch.quantization.utils import is_quantized_linear
 from modelopt.torch.utils import torch_to
 
@@ -116,38 +119,95 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
         mto.restore_from_modelopt_state(model_ref, state_dict)
 
 
-def tensor_parallel_test_helper(model, config, tp_group, dp_group):
-    # The input to fist layer, the column parallel should be the same across all tp ranks
-    calib_data = model.get_dummy_input().cuda()
-    dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
+    quantizer_attr = getattr(quantizer, attr).clone()
+    for group in groups:
+        if group is not None:
+            dist.all_reduce(quantizer_attr, op=op, group=group)
+    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
 
-    def forward_loop(model):
-        model(calib_data)
 
-    model = mtq.quantize(model, config, forward_loop)
+original_awq_lite = model_calib_module.awq_lite
 
-    # Sanity check
-    forward_loop(model)
 
-    if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
-        # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks
-        activation_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
 | 134 | +    """Function to mock awq_lite function to always use debug=True for testing"""  | 
+    return original_awq_lite(model, forward_loop, alpha_step, debug=True, **kwargs)
 
-        # Lets check the row parallel weight amax; it should be the same across all tp ranks
-        weight_amax = model.fc2.weight_quantizer.amax.clone()
-        dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax)
 
-    if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
-        input_quantizer = model.fc1.input_quantizer
-        pre_quant_scale = input_quantizer.pre_quant_scale.clone()
-        dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale)
+@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
+def data_tensor_context_parallel_test_helper(
+    model, config, mock_awq_lite, dp_group=None, tp_group=None, test_pre_quant_scale=True
+):
+    # Calib data should be different across each DP rank
+    dp_rank = dist.get_rank(group=dp_group)
+    calib_data = model.get_dummy_input(seed=dp_rank).cuda()
+
+    if tp_group is not None:
+        # The input to the first (column-parallel) layer should be the same across all TP ranks
+        dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
 
-    dist.destroy_process_group()
+    def forward_loop(model):
+        model(calib_data)
+
+    model = mtq.quantize(model, config, forward_loop)
+
+    # Input quantizer amax
+    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
+        _distributed_attr_check(
+            model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+        _distributed_attr_check(
+            model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+
+    # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
+    # Channel-wise (INT8) only expects same amax across row parallel ranks
+    # Block-wise quantization does not expect same amax across row and column parallel ranks
+    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
+        if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc1.weight_quantizer:
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
+        else:
+            _distributed_attr_check(
+                model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    if config in [
+        mtq.FP8_DEFAULT_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+        mtq.INT8_DEFAULT_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+    ]:
+        if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc2.weight_quantizer:
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
+        else:
+            _distributed_attr_check(
+                model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    # Let's check the column parallel pre_quant_scale; it should be the same across all TP ranks
+    # It is different across DP/CP ranks since the input is different
+    if (
+        test_pre_quant_scale
+        and tp_group
+        and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]
+    ):
+        input_quantizer = model.fc1.input_quantizer
+        _distributed_attr_check(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+
+    # Check act scale
+    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        _distributed_attr_check(
+            model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
+        )
 
 
 def auto_quantize_helper(model):
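For reference, a minimal sketch of how a helper like data_tensor_context_parallel_test_helper might be driven from a spawned multi-GPU test. The ToyModel class, the two-rank data-parallel layout, the _run entry point, and the MASTER_ADDR/MASTER_PORT defaults below are illustrative assumptions, not part of this change, and may need adaptation to the real test models.

# Hypothetical usage sketch (assumptions: ToyModel, 2 GPUs, one DP group, no TP group)
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import modelopt.torch.quantization as mtq


class ToyModel(torch.nn.Module):
    """Assumed stand-in for the test model: two linear layers named fc1/fc2."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(32, 64)
        self.fc2 = torch.nn.Linear(64, 32)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

    def get_dummy_input(self, seed=0):
        # A different seed per DP rank yields different calibration data per rank
        torch.manual_seed(seed)
        return torch.randn(4, 32)


def _run(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # All ranks form a single data-parallel group; no tensor parallelism in this sketch
    dp_group = dist.new_group(ranks=list(range(world_size)))
    model = ToyModel().cuda()
    # The @patch decorator on the helper injects mock_awq_lite automatically
    data_tensor_context_parallel_test_helper(
        model, mtq.FP8_DEFAULT_CFG, dp_group=dp_group, tp_group=None
    )
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_run, args=(2,), nprocs=2)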
 | 