
Commit 42519cc

lint
Signed-off-by: Jennifer Chen <[email protected]>
1 parent: f17131f

7 files changed: +49 −26 lines

modelopt/torch/quantization/model_calib.py

Lines changed: 12 additions & 4 deletions
@@ -628,10 +628,14 @@ def forward(self, input, *args, **kwargs):
     def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
         # Sync across Data Parallel (DP)
         if data_parallel_group.is_initialized():
-            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group)
+            dist.all_reduce(
+                module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
+            )
         # Sync across Context Parallel (CP)
         if context_parallel_group.is_initialized():
-            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group)
+            dist.all_reduce(
+                module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group
+            )

     for name, module in model.named_modules():
         if (
@@ -640,8 +644,12 @@ def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_gr
             and module.awq_lite.num_cache_steps > 0
         ):
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
-            sync_act_scale_across_dp_cp(module, module.parallel_state.data_parallel_group, module.parallel_state.context_parallel_group)
-
+            sync_act_scale_across_dp_cp(
+                module,
+                module.parallel_state.data_parallel_group,
+                module.parallel_state.context_parallel_group,
+            )
+
             # Hack: MoEs forward all tokens through all experts if _if_calib is True
             module._if_calib = True

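Note: the hunks above only rewrap the existing synchronization logic for line length. As a standalone illustration of the pattern (not the ModelOpt implementation itself), averaging a per-module activation scale across a process group looks roughly like the sketch below; `scale` and `group` are placeholder names.

import torch
import torch.distributed as dist

def average_across_group(scale: torch.Tensor, group) -> None:
    # Average `scale` in place across all ranks of `group`; skip if there is nothing to sync.
    if group is None or not dist.is_initialized():
        return
    # ReduceOp.AVG requires the NCCL backend; on other backends, all-reduce a SUM
    # and divide by the group size instead.
    dist.all_reduce(scale, op=dist.ReduceOp.AVG, group=group)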
modelopt/torch/quantization/plugins/megatron.py

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,8 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
-from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.parallel_state import get_data_parallel_group
+from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
@@ -221,7 +221,7 @@ def _setup(self):
         data_parallel_group = None
         try:
             data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-        except:
+        except AssertionError:
             data_parallel_group = get_data_parallel_group()
         self.parallel_state = ParallelState(
             data_parallel_group,

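Note: the second hunk narrows a bare `except:` to `except AssertionError:`, which is the error raised by `get_data_parallel_group(with_context_parallel=True)` when no context-parallel-aware data-parallel group has been initialized. A minimal sketch of the fallback, assuming Megatron-Core model-parallel state is already set up:

from megatron.core.parallel_state import get_data_parallel_group

try:
    # Prefer the data-parallel group that folds in context parallelism.
    dp_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
    # Raised when parallel state was initialized without context parallelism.
    dp_group = get_data_parallel_group()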
modelopt/torch/utils/distributed.py

Lines changed: 5 additions & 1 deletion
@@ -249,7 +249,11 @@ def __init__(
         self.context_parallel_group = DistributedProcessGroup(context_parallel_group)

     def __repr__(self) -> str:
-        return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}, context_parallel_group: {self.context_parallel_group}"
+        return (
+            f"data_parallel_group: {self.data_parallel_group}, "
+            f"tensor_parallel_group: {self.tensor_parallel_group}, "
+            f"context_parallel_group: {self.context_parallel_group}"
+        )


 def get_group(ranks: list[int]):

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 5 additions & 1 deletion
@@ -390,7 +390,11 @@ def initialize_for_megatron(

     NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
     """
-    initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size)
+    initialize_model_parallel(
+        tensor_model_parallel_size,
+        pipeline_model_parallel_size,
+        context_parallel_size=context_parallel_size,
+    )
     model_parallel_cuda_manual_seed(seed)


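Note: a hedged usage sketch of the reformatted `initialize_model_parallel` call, assuming Megatron-Core and an already-initialized torch.distributed process group; the sizes below are illustrative only.

from megatron.core.parallel_state import destroy_model_parallel, initialize_model_parallel
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed

initialize_model_parallel(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    context_parallel_size=2,
)
model_parallel_cuda_manual_seed(1234)  # deterministic CUDA RNG state per model-parallel rank
# ... run the test body ...
destroy_model_parallel()  # required cleanup when not running in a spawned process (see NOTE above)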
tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 4 additions & 0 deletions
@@ -149,6 +149,7 @@ def forward_loop(model):

     dist.destroy_process_group()

+
 def data_parallel_test_helper(model, config, dp_group):
     calib_data = model.get_dummy_input().cuda()

@@ -165,6 +166,7 @@ def forward_loop(model):
     dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
     assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)

+
 def context_parallel_test_helper(model, config, cp_group):
     calib_data = model.get_dummy_input().cuda()

@@ -181,6 +183,7 @@ def forward_loop(model):
     dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
     assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)

+
 def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
     calib_data = model.get_dummy_input().cuda()
     # data should be same across each TP rank
@@ -203,6 +206,7 @@ def forward_loop(model):
     dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
     assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)

+
 def auto_quantize_helper(model):
     model, search_state = mtq.auto_quantize(
         model,

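Note: each helper above ends with the same check (visible in the context lines): all-reduce the rank-local `amax` with MAX over the relevant group and assert the local quantizer already holds that value, i.e. calibration synchronized amax across ranks. A minimal sketch with placeholder names `quantizer` and `group`:

import torch
import torch.distributed as dist

def assert_amax_synced(quantizer, group) -> None:
    # Compute the group-wide maximum on a copy, then compare it with the local value.
    amax = quantizer.amax.clone()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    assert torch.allclose(amax, quantizer.amax), "amax differs across ranks"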
tests/gpu/torch/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -33,13 +33,13 @@ def need_2_gpus():
     if torch.cuda.device_count() < 2:
         pytest.skip("Need at least 2 GPUs to run this test")

+
 @pytest.fixture
 def need_8_gpus():
     if torch.cuda.device_count() < 8:
         pytest.skip("Need at least 8 GPUs to run this test")


-
 @pytest.fixture(scope="module")
 def set_torch_dtype(request):
     orig_dtype = torch.get_default_dtype()

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 20 additions & 17 deletions
@@ -31,10 +31,10 @@
 from _test_utils.torch_quantization.quant_utils import get_model_size
 from _test_utils.torch_quantization.quantize_common import (
     auto_quantize_helper,
-    tensor_parallel_test_helper,
-    data_parallel_test_helper,
     context_parallel_test_helper,
+    data_parallel_test_helper,
     data_tensor_context_parallel_test_helper,
+    tensor_parallel_test_helper,
 )
 from packaging.version import Version

@@ -43,8 +43,8 @@
 import megatron.core
 from megatron.core.parallel_state import (
     destroy_model_parallel,
-    get_data_parallel_group,
     get_context_parallel_group,
+    get_data_parallel_group,
     get_tensor_model_parallel_group,
 )
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
@@ -95,14 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
     # Clean up since this is not a spawned process
     destroy_model_parallel()

+
 # 1. Tensor Parallel Test
 def _test_tensor_parallel_helper(config, rank, size):
     initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
     model = MegatronModel(tp_size=size).cuda()

-    tensor_parallel_test_helper(
-        model, config, get_tensor_model_parallel_group()
-    )
+    tensor_parallel_test_helper(model, config, get_tensor_model_parallel_group())


 @pytest.mark.parametrize(
@@ -122,15 +121,14 @@ def test_tensor_parallel(need_2_gpus, config):
         size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl"
     )

+
 # 2. Data Parallel Test
 def _test_data_parallel_helper(config, rank, size):
     # TODO does this model automatically get copied to both DP ranks?
     initialize_for_megatron(seed=SEED)
     model = MegatronModel().cuda()

-    data_parallel_test_helper(
-        model, config, get_data_parallel_group()
-    )
+    data_parallel_test_helper(model, config, get_data_parallel_group())


 @pytest.mark.parametrize(
@@ -146,18 +144,16 @@ def _test_data_parallel_helper(config, rank, size):
     ],
 )
 def test_data_parallel(need_2_gpus, config):
-    spawn_multiprocess_job(
-        size=2, job=partial(_test_data_parallel_helper, config), backend="nccl"
-    )
+    spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl")
+

 # 3. Context Parallel Test
 def _test_context_parallel_helper(config, rank, size):
     initialize_for_megatron(context_parallel_size=size, seed=SEED)
     model = MegatronModel(cp_size=size).cuda()

-    context_parallel_test_helper(
-        model, config, get_context_parallel_group()
-    )
+    context_parallel_test_helper(model, config, get_context_parallel_group())
+

 @pytest.mark.parametrize(
     "config",
@@ -176,15 +172,21 @@ def test_context_parallel(need_2_gpus, config):
         size=2, job=partial(_test_context_parallel_helper, config), backend="nccl"
     )

+
 # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
 def _test_data_tensor_context_parallel_helper(config, rank, size):
     initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
     model = MegatronModel(tp_size=2, cp_size=2).cuda()

     data_tensor_context_parallel_test_helper(
-        model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group()
+        model,
+        config,
+        get_data_parallel_group(),
+        get_tensor_model_parallel_group(),
+        get_context_parallel_group(),
     )

+
 @pytest.mark.parametrize(
     "config",
     [
@@ -199,9 +201,10 @@ def _test_data_tensor_context_parallel_helper(config, rank, size):
 )
 def test_data_tensor_context_parallel(need_8_gpus, config):
     spawn_multiprocess_job(
-        size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
+        size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
     )

+
 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""
