Commit 77495f5

Updated test_tp_sharding
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 616c1cc commit 77495f5

File tree

  • tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py

2 files changed: +80 −19 lines

tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py

Lines changed: 20 additions & 10 deletions

@@ -322,6 +322,25 @@ def create_sharding_from_config(

         for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op):
             module_name = list(lin_node.meta["nn_module_stack"].keys())[-1]
+
+            # If the node is inside the attention module, we need to set min_local_shape to the
+            # head_dim - otherwise, we would risk splitting the heads into smaller shards.
+            # TODO: is there a better way to check if we are in attention module?
+            attn_names = [
+                "attention",
+                "Attention",
+                "attn",
+                "Attn",
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+            ]
+            if any(attn_name in module_name for attn_name in attn_names):
+                min_local_shape = head_dim
+            else:
+                min_local_shape = 1
+
             # use regex to find if module_name matches any of the keys in sharding_config
             for key in tp_plan.keys():
                 pattern_string = "*" + key + "*"
@@ -338,15 +357,6 @@ def create_sharding_from_config(
                 # all-gather after column, and all-reduce after row.
                 # But since we assume Y = W @ X^T, we have a swapped column and row split.
                 if config == "colwise":
-                    # if we are doing colwise split, we need to check if we are in
-                    # attention module. If so, we need to set min_local_shape to the
-                    # head_dim - otherwise, we would risk splitting the heads into smaller shards.
-                    # TODO: is there a better way to check if we are in attention module?
-                    attn_names = ["attention", "Attention", "attn", "Attn"]
-                    if any(attn_name in module_name for attn_name in attn_names):
-                        min_local_shape = head_dim
-                    else:
-                        min_local_shape = 1
                     self.tp_transforms.append(
                         TPShardingInfo(
                             target_node=lin_node.name,
@@ -365,7 +375,7 @@ def create_sharding_from_config(
                             rank=self.rank,
                             world_size=self.world_size,
                             dist_op="all_reduce",
-                            min_local_shape=1,
+                            min_local_shape=min_local_shape,
                         )
                     )
                 elif "sequence" in config:

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py

Lines changed: 60 additions & 9 deletions

@@ -21,6 +21,35 @@
 )
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op

+base_model_tp_plan = {
+    "q_proj": "colwise",
+    "k_proj": "colwise",
+    "v_proj": "colwise",
+    "o_proj": "rowwise",
+    "gate_proj": "colwise",
+    "up_proj": "colwise",
+    "down_proj": "rowwise",
+    "linear1": "colwise",
+    "linear2": "rowwise",
+    "linear": "gather",
+    "input_layernorm.weight": "sequence_parallel",
+    "post_attention_layernorm.weight": "sequence_parallel",
+    "norm.weight": "sequence_parallel",
+    "shared_expert.gate_proj": "local_colwise",
+    "shared_expert.up_proj": "local_colwise",
+    "shared_expert.down_proj": "local_rowwise",
+    "experts.gate_up_proj": "local_packed_rowwise",
+    "experts.down_proj": "local_colwise",
+    "experts": "local",
+    "feed_forward": "gather",
+    "self": "gather",
+}
+
+predefined_config = {
+    "head_dim": 8,
+    "tp_plan": base_model_tp_plan,
+}
+

 class GQA_Block(nn.Module):
     def __init__(
@@ -83,6 +112,7 @@ def _run_job(
     model_cls: nn.Module,
     dist_op_expected: str,
     bias: bool,
+    from_config: bool,
     rank: int,
     world_size: int,
 ) -> None:
@@ -129,6 +159,7 @@ def _get_expected_num_params(num_p_og: int) -> int:
            num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size
        else:
            num_params = num_p_og // world_size + num_update
+       print(f"\n\nnum_p_og: {num_p_og}, num_params: {num_params}")
        return num_params

    def verify_local_weight_sizes(gm) -> bool:
@@ -147,8 +178,12 @@ def verify_local_weight_sizes(gm) -> bool:
     op_expected = getattr(torch.ops.auto_deploy, dist_op_expected)

     def transform_func(gm) -> None:
-        sharding_config = ShardingConfig()
-        detect_column_row_shard(gm, rank, world_size, sharding_config)
+        sharding_config = ShardingConfig(rank=rank, world_size=world_size)
+        if from_config:
+            if world_size > 1:
+                sharding_config.create_sharding_from_config(gm, predefined_config)
+        else:
+            detect_column_row_shard(gm, rank, world_size, sharding_config)
         sharding_transform_executor(gm, sharding_config)

     def combined_graph_check(gm) -> bool:
@@ -174,6 +209,7 @@ def _run_pattern_detection_job(
     bias: bool,
     rank: int,
     world_size: int,
+    from_config: bool,
 ) -> None:
     # init model and input
     batch_size = 4
@@ -200,7 +236,7 @@ def _run_pattern_detection_job(
     gm = torch_export_to_gm(model, args=(x,), clone=True)
     expected_transformations = []
     # if world_size == 1, no sharding transformations should be detected
-    if world_size > 1:
+    if world_size > 1 or from_config:
         if model_cls == GQA_Block:
             min_local_shape = num_features // num_heads
         for node in gm.graph.nodes:
@@ -262,8 +298,11 @@ def _run_pattern_detection_job(
            )

     # get detected transformations
-    sharding_config = ShardingConfig()
-    detect_column_row_shard(gm, rank, world_size, sharding_config)
+    sharding_config = ShardingConfig(rank=rank, world_size=world_size)
+    if from_config:
+        sharding_config.create_sharding_from_config(gm, predefined_config)
+    else:
+        detect_column_row_shard(gm, rank, world_size, sharding_config)
     detected_transformations = sharding_config.tp_transforms

     # Run pattern detection test
@@ -272,6 +311,7 @@ def _run_pattern_detection_job(

 @pytest.mark.parametrize("device_count", get_device_counts())
 @pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("from_config", [False, True])
 @pytest.mark.parametrize(
     "model_cls, dist_op_expected",
     (
@@ -280,15 +320,22 @@ def _run_pattern_detection_job(
         (GQA_Block, "torch_dist_all_reduce"),
     ),
 )
-def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, device_count: int):
+def test_sharding(
+    model_cls: Type[nn.Module],
+    dist_op_expected: str,
+    bias: bool,
+    device_count: int,
+    from_config: bool,
+):
     dist_common.spawn_multiprocess_job(
-        job=partial(_run_job, model_cls, dist_op_expected, bias),
+        job=partial(_run_job, model_cls, dist_op_expected, bias, from_config),
         size=device_count,
     )


 @pytest.mark.parametrize("world_size", [1, 8])
 @pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("from_config", [False, True])
 @pytest.mark.parametrize(
     "model_cls, dist_op_expected",
     (
@@ -298,11 +345,15 @@ def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool,
     ),
 )
 def test_sharding_pattern_detection(
-    model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, world_size: int
+    model_cls: Type[nn.Module],
+    dist_op_expected: str,
+    bias: bool,
+    world_size: int,
+    from_config: bool,
 ):
     """Test pattern detection logic without distributed execution.

     This test verifies only the pattern detection logic with provided world_size.
     No need to run distributed job, can be run on single process.
     """
-    _run_pattern_detection_job(model_cls, bias, 0, world_size)
+    _run_pattern_detection_job(model_cls, bias, 0, world_size, from_config)
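
Note: the tp_plan keys above act as substring-style patterns; the sharding pass wraps each key as "*" + key + "*" and matches it against the node's module path from nn_module_stack. A rough sketch of that lookup, approximated here with fnmatch and made-up module paths (the real implementation may differ in matching details):

from fnmatch import fnmatch

tp_plan = {"q_proj": "colwise", "o_proj": "rowwise", "down_proj": "rowwise"}

def lookup_sharding(module_name):
    # Mirrors pattern_string = "*" + key + "*" from the sharding.py diff.
    for key, action in tp_plan.items():
        if fnmatch(module_name, "*" + key + "*"):
            return action
    return None

assert lookup_sharding("model.layers.0.self_attn.q_proj") == "colwise"
assert lookup_sharding("model.layers.0.mlp.down_proj") == "rowwise"
assert lookup_sharding("model.layers.0.input_layernorm") is None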
