
Commit d02365c

awq test

Signed-off-by: Jennifer Chen <[email protected]>

1 parent 71a9f7a · commit d02365c

File tree: 5 files changed, +56 -15 lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 9 additions & 9 deletions
@@ -617,20 +617,20 @@ def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
             and hasattr(module, "awq_lite")
             and module.awq_lite.num_cache_steps > 0
         ):
+            # Hack: MoEs forward all tokens through all experts if _if_calib is True
+            module._if_calib = True
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
-
+
             if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
                 torch.isnan(module.awq_lite.weight_scale)
             ):
                 module.awq_lite.is_enabled = False
-
-            sync_act_scale_across_dp_cp(
-                module,
-                module.parallel_state.data_parallel_group,
-                module.parallel_state.context_parallel_group,
-            )
-            # Hack: MoEs forward all tokens through all experts if _if_calib is True
-            module._if_calib = True
+            else:
+                sync_act_scale_across_dp_cp(
+                    module,
+                    module.parallel_state.data_parallel_group,
+                    module.parallel_state.context_parallel_group,
+                )
 
     AWQLiteHelper.cache_mode = False
     print_rank_0("awq_lite: Searching parameters...")

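The reordering above enables the MoE calibration hack before the cached activation scales are averaged, and it now synchronizes act_scale across data- and context-parallel ranks only when the scales are finite; if a NaN shows up in either scale, awq_lite is disabled for that module instead. As a minimal sketch (an assumption, not the code from model_calib.py), sync_act_scale_across_dp_cp could be an averaging all-reduce over both groups, mirroring the AVG reduction used by the new test at the end of this commit:

    import torch.distributed as dist

    def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
        # Sketch only: average the cached activation scale over the DP group and then the
        # CP group so every replica searches AWQ-lite parameters from identical statistics.
        for group in (data_parallel_group, context_parallel_group):
            if group is not None:
                dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=group)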
tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 4 additions & 1 deletion
@@ -384,7 +384,10 @@ def run_mcore_inference_with_dummy_input(
 
 
 def initialize_for_megatron(
-    tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1, seed=1234
+    tensor_model_parallel_size=1,
+    pipeline_model_parallel_size=1,
+    seed=1234,
+    context_parallel_size=1,
 ):
     """Initialize Megatron model parallelism.

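Note that reordering defaulted parameters is only transparent to callers that pass them by keyword: under the old signature a positional call such as initialize_for_megatron(1, 1, 2, 1234) set context_parallel_size=2, whereas under the new one the 2 would bind to seed. The call sites touched by this commit all use keywords, for example (SEED, rank, and size come from the test helpers below):

    initialize_for_megatron(seed=SEED + rank)
    initialize_for_megatron(context_parallel_size=size, seed=SEED + rank)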
tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 6 additions & 2 deletions
@@ -194,9 +194,13 @@ def forward_loop(model):
 
     def reduce_amax(quantizer):
         amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
+        print("amax before reduce", amax)
+        print("quantizer.amax before reduce", quantizer.amax)
         dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
+        print("amax after reduce", amax)
+        print("quantizer.amax after reduce", quantizer.amax)
         assert torch.allclose(amax, quantizer.amax)
 
     # Input quantizer amax

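Since max is commutative and associative, reducing over the dp, cp, and tp groups in any order yields the same global maximum, so swapping the reduction order does not change what the assertion verifies; the extra prints are plain debug output around it. The invariant can be restated compactly as a sketch (dp_group, cp_group, and tp_group are assumed to be the process groups the test helper already builds):

    import torch
    import torch.distributed as dist

    def assert_amax_synced(quantizer, groups):
        # Reduce a copy of this rank's amax to the maximum over all groups; if calibration
        # already synchronized amax, the reduced value equals the stored one on every rank.
        amax = quantizer.amax.clone()
        for group in groups:  # order does not matter for a MAX reduction
            dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
        assert torch.allclose(amax, quantizer.amax)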
tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 4 additions & 3 deletions
@@ -123,8 +123,7 @@ def test_tensor_parallel(need_2_gpus, config):
 
 # 2. Data Parallel Test
 def _test_data_parallel_helper(config, rank, size):
-    # TODO does this model automatically get copied to both DP ranks?
-    initialize_for_megatron(seed=SEED)
+    initialize_for_megatron(seed=SEED + rank)  # modify seed so data is different across ranks
     model = MegatronModel().cuda()
 
     dp_cp_parallel_test_helper(model, config, get_data_parallel_group())
@@ -148,7 +147,9 @@ def test_data_parallel(need_2_gpus, config):
 
 # 3. Context Parallel Test
 def _test_context_parallel_helper(config, rank, size):
-    initialize_for_megatron(context_parallel_size=size, seed=SEED)
+    initialize_for_megatron(
+        context_parallel_size=size, seed=SEED + rank
+    )  # modify seed so data is different across ranks
     model = MegatronModel(cp_size=size).cuda()
 
     dp_cp_parallel_test_helper(model, config, get_context_parallel_group())
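Seeding each rank with SEED + rank makes the dummy calibration data differ across data- and context-parallel ranks, so the per-rank statistics genuinely diverge before synchronization and the helper's assertions become meaningful. A hedged sketch of how these helpers are typically driven, assuming spawn_multiprocess_job passes (rank, size) to the job the same way the new AWQ test below does:

    from functools import partial

    # config is bound up front; spawn_multiprocess_job supplies rank and size per process.
    spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl")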
New test file

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+import torch
+import torch.distributed as dist
+from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
+from _test_utils.torch_dist.plugins.megatron_common import MegatronModel, initialize_for_megatron
+from megatron.core.parallel_state import get_data_parallel_group
+
+from modelopt.torch.quantization.model_calib import awq_lite
+
+
+def _test_awq_lite_act_scale_sync_helper(rank, size):
+    initialize_for_megatron(seed=1234 + rank)
+    model = MegatronModel().cuda()
+
+    calib_data = model.get_dummy_input().cuda()
+
+    def forward_loop(model):
+        model(calib_data)
+
+    model = awq_lite(model, forward_loop)
+    # Sanity check
+    forward_loop(model)
+
+    act_scale = model.fc1.weight_quantizer.awq_lite.act_scale.clone()
+    dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group())
+    assert torch.allclose(act_scale, model.fc1.weight_quantizer.awq_lite.act_scale)
+
+    act_scale = model.fc2.weight_quantizer.awq_lite.act_scale.clone()
+    dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group())
+    assert torch.allclose(act_scale, model.fc2.weight_quantizer.awq_lite.act_scale)
+
+
+def test_awq_lite_act_scale_sync(need_2_gpus):
+    spawn_multiprocess_job(size=2, job=_test_awq_lite_act_scale_sync_helper, backend="nccl")

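The new test calibrates a two-rank data-parallel model on per-rank data and then checks that each layer's act_scale already equals the data-parallel average: averaging a tensor that is already the group mean is a no-op, so the allclose checks pass only if the scales were synchronized during calibration. A self-contained illustration of that idempotence (hypothetical two-rank values, no distributed setup required):

    import torch

    rank_scales = [torch.tensor([0.9, 1.1]), torch.tensor([1.3, 0.7])]  # pre-sync act_scale per rank
    group_mean = torch.stack(rank_scales).mean(dim=0)                   # what the DP sync produces
    # Averaging tensors that already equal the group mean returns the same tensor.
    assert torch.allclose(torch.stack([group_mean, group_mean]).mean(dim=0), group_mean)

On a machine with two GPUs, the case can be selected with pytest -k test_awq_lite_act_scale_sync (both the need_2_gpus fixture and the nccl backend require them).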